[
  {
    "path": ".github/workflows/build.yml",
    "content": "name: build\non: [push, pull_request]\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ruby/setup-ruby@v1\n        with:\n          ruby-version: \"4.0\"\n          bundler-cache: true\n      - uses: actions/cache@v5\n        with:\n          path: ~/.cache/informers\n          key: informers\n      - run: sudo apt-get update && sudo apt-get install libvips\n      - run: bundle exec rake download:files\n      - run: bundle exec rake test\n"
  },
  {
    "path": ".gitignore",
    "content": "/.bundle/\n/.yardoc\n/_yardoc/\n/coverage/\n/doc/\n/pkg/\n/spec/reports/\n/test/support/\n/tmp/\n*.lock\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "## 1.3.0 (unreleased)\n\n- Dropped support for Ruby < 3.3\n\n## 1.2.1 (2025-02-01)\n\n- Fixed error when terminal width is zero\n\n## 1.2.0 (2024-11-14)\n\n- Added support for models with external data\n- Added `device` option\n- Added `dtype` option\n- Added `session_options` option\n\n## 1.1.1 (2024-10-14)\n\n- Added `audio-classification` pipeline\n- Fixed error with `sentence-transformers/all-MiniLM-L6-v2`\n\n## 1.1.0 (2024-09-17)\n\n- Added more pipelines\n\n## 1.0.3 (2024-08-29)\n\n- Added `model_output` option\n- Improved `model_file_name` option\n\n## 1.0.2 (2024-08-28)\n\n- Added `embedding` pipeline\n- Added experimental `reranking` pipeline\n- Added support for `nomic-ai/nomic-embed-text-v1`\n\n## 1.0.1 (2024-08-27)\n\n- Added support for `Supabase/gte-small` to `Model`\n- Fixed error with downloads\n\n## 1.0.0 (2024-08-26)\n\n- Replaced task classes with `pipeline` method\n- Added `Model` class\n- Dropped support for Ruby < 3.1\n\n## 0.2.0 (2022-09-06)\n\n- Added support for `optimum` and `transformers.onnx` models\n- Dropped support for Ruby < 2.7\n\n## 0.1.3 (2021-09-25)\n\n- Added text generation\n- Added fill mask\n\n## 0.1.2 (2020-11-24)\n\n- Added feature extraction\n\n## 0.1.1 (2020-10-05)\n\n- Fixed question answering for Ruby < 2.7\n\n## 0.1.0 (2020-10-01)\n\n- First release\n"
  },
  {
    "path": "Gemfile",
    "content": "source \"https://rubygems.org\"\n\ngemspec\n\ngem \"rake\"\ngem \"minitest\"\ngem \"ruby-vips\"\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Informers\n\n:fire: Fast [transformer](https://github.com/huggingface/transformers.js) inference for Ruby\n\nFor non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby) :slightly_smiling_face:\n\n[![Build Status](https://github.com/ankane/informers/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/informers/actions)\n\n## Installation\n\nAdd this line to your application’s Gemfile:\n\n```ruby\ngem \"informers\"\n```\n\n## Getting Started\n\n- [Models](#models)\n- [Pipelines](#pipelines)\n\n## Models\n\nEmbedding\n\n- [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)\n- [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)\n- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)\n- [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2)\n- [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)\n- [Supabase/gte-small](#supabasegte-small)\n- [intfloat/e5-base-v2](#intfloate5-base-v2)\n- [nomic-ai/nomic-embed-text-v1](#nomic-ainomic-embed-text-v1)\n- [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)\n- [jinaai/jina-embeddings-v2-base-en](#jinaaijina-embeddings-v2-base-en)\n- [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)\n\nReranking\n\n- [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)\n- [jinaai/jina-reranker-v1-turbo-en](#jinaaijina-reranker-v1-turbo-en)\n- [BAAI/bge-reranker-base](#baaibge-reranker-base)\n- [Xenova/ms-marco-MiniLM-L-6-v2](#xenovams-marco-minilm-l-6-v2)\n\n### sentence-transformers/all-MiniLM-L6-v2\n\n[Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n\n```ruby\nsentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\nmodel = Informers.pipeline(\"embedding\", \"sentence-transformers/all-MiniLM-L6-v2\")\nembeddings = 
model.(sentences)\n```\n\n### sentence-transformers/multi-qa-MiniLM-L6-cos-v1\n\n[Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)\n\n```ruby\nquery = \"How many people live in London?\"\ndocs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\nmodel = Informers.pipeline(\"embedding\", \"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\")\nquery_embedding = model.(query)\ndoc_embeddings = model.(docs)\nscores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }\ndoc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }\n```\n\n### sentence-transformers/all-mpnet-base-v2\n\n[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)\n\n```ruby\nsentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\nmodel = Informers.pipeline(\"embedding\", \"sentence-transformers/all-mpnet-base-v2\")\nembeddings = model.(sentences)\n```\n\n### sentence-transformers/paraphrase-MiniLM-L6-v2\n\n[Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)\n\n```ruby\nsentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\nmodel = Informers.pipeline(\"embedding\", \"sentence-transformers/paraphrase-MiniLM-L6-v2\")\nembeddings = model.(sentences, normalize: false)\n```\n\n### mixedbread-ai/mxbai-embed-large-v1\n\n[Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)\n\n```ruby\nquery_prefix = \"Represent this sentence for searching relevant passages: \"\n\ninput = [\n  \"The dog is barking\",\n  \"The cat is purring\",\n  query_prefix + \"puppy\"\n]\n\nmodel = Informers.pipeline(\"embedding\", \"mixedbread-ai/mxbai-embed-large-v1\")\nembeddings = model.(input)\n```\n\n### Supabase/gte-small\n\n[Docs](https://huggingface.co/Supabase/gte-small)\n\n```ruby\nsentences = [\"That is a happy person\", \"That is a very happy person\"]\n\nmodel = Informers.pipeline(\"embedding\", \"Supabase/gte-small\")\nembeddings = 
model.(sentences)\n```\n\n### intfloat/e5-base-v2\n\n[Docs](https://huggingface.co/intfloat/e5-base-v2)\n\n```ruby\ndoc_prefix = \"passage: \"\nquery_prefix = \"query: \"\n\ninput = [\n  doc_prefix + \"Ruby is a programming language created by Matz\",\n  query_prefix + \"Ruby creator\"\n]\n\nmodel = Informers.pipeline(\"embedding\", \"intfloat/e5-base-v2\")\nembeddings = model.(input)\n```\n\n### nomic-ai/nomic-embed-text-v1\n\n[Docs](https://huggingface.co/nomic-ai/nomic-embed-text-v1)\n\n```ruby\ndoc_prefix = \"search_document: \"\nquery_prefix = \"search_query: \"\n\ninput = [\n  doc_prefix + \"The dog is barking\",\n  doc_prefix + \"The cat is purring\",\n  query_prefix + \"puppy\"\n]\n\nmodel = Informers.pipeline(\"embedding\", \"nomic-ai/nomic-embed-text-v1\")\nembeddings = model.(input)\n```\n\n### BAAI/bge-base-en-v1.5\n\n[Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)\n\n```ruby\nquery_prefix = \"Represent this sentence for searching relevant passages: \"\n\ninput = [\n  \"The dog is barking\",\n  \"The cat is purring\",\n  query_prefix + \"puppy\"\n]\n\nmodel = Informers.pipeline(\"embedding\", \"BAAI/bge-base-en-v1.5\")\nembeddings = model.(input)\n```\n\n### jinaai/jina-embeddings-v2-base-en\n\n[Docs](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)\n\n```ruby\nsentences = [\"How is the weather today?\", \"What is the current weather like today?\"]\n\nmodel = Informers.pipeline(\"embedding\", \"jinaai/jina-embeddings-v2-base-en\", model_file_name: \"../model\")\nembeddings = model.(sentences)\n```\n\n### Snowflake/snowflake-arctic-embed-m-v1.5\n\n[Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)\n\n```ruby\nquery_prefix = \"Represent this sentence for searching relevant passages: \"\n\ninput = [\n  \"The dog is barking\",\n  \"The cat is purring\",\n  query_prefix + \"puppy\"\n]\n\nmodel = Informers.pipeline(\"embedding\", \"Snowflake/snowflake-arctic-embed-m-v1.5\")\nembeddings = model.(input, model_output: 
\"sentence_embedding\", pooling: \"none\")\n```\n\n### mixedbread-ai/mxbai-rerank-base-v1\n\n[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)\n\n```ruby\nquery = \"How many people live in London?\"\ndocs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\nmodel = Informers.pipeline(\"reranking\", \"mixedbread-ai/mxbai-rerank-base-v1\")\nresult = model.(query, docs)\n```\n\n### jinaai/jina-reranker-v1-turbo-en\n\n[Docs](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en)\n\n```ruby\nquery = \"How many people live in London?\"\ndocs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\nmodel = Informers.pipeline(\"reranking\", \"jinaai/jina-reranker-v1-turbo-en\")\nresult = model.(query, docs)\n```\n\n### BAAI/bge-reranker-base\n\n[Docs](https://huggingface.co/BAAI/bge-reranker-base)\n\n```ruby\nquery = \"How many people live in London?\"\ndocs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\nmodel = Informers.pipeline(\"reranking\", \"BAAI/bge-reranker-base\")\nresult = model.(query, docs)\n```\n\n### Xenova/ms-marco-MiniLM-L-6-v2\n\n[Docs](https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2)\n\n```ruby\nquery = \"How many people live in London?\"\ndocs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\nmodel = Informers.pipeline(\"reranking\", \"Xenova/ms-marco-MiniLM-L-6-v2\")\nresult = model.(query, docs)\n```\n\n### Other\n\nThe model must include a `.onnx` file ([example](https://huggingface.co/Xenova/all-MiniLM-L6-v2/tree/main/onnx)). 
If the file is not at `onnx/model.onnx`, use the `model_file_name` option to specify the location.\n\n## Pipelines\n\n- [Text](#text)\n- [Vision](#vision)\n- [Audio](#audio)\n- [Multimodal](#multimodal)\n\n### Text\n\nEmbedding\n\n```ruby\nembed = Informers.pipeline(\"embedding\")\nembed.(\"We are very happy to show you the 🤗 Transformers library.\")\n```\n\nReranking\n\n```ruby\nrerank = Informers.pipeline(\"reranking\")\nrerank.(\"Who created Ruby?\", [\"Matz created Ruby\", \"Another doc\"])\n```\n\nNamed-entity recognition\n\n```ruby\nner = Informers.pipeline(\"ner\")\nner.(\"Ruby is a programming language created by Matz\")\n```\n\nSentiment analysis\n\n```ruby\nclassifier = Informers.pipeline(\"sentiment-analysis\")\nclassifier.(\"We are very happy to show you the 🤗 Transformers library.\")\n```\n\nQuestion answering\n\n```ruby\nqa = Informers.pipeline(\"question-answering\")\nqa.(\"Who invented Ruby?\", \"Ruby is a programming language created by Matz\")\n```\n\nZero-shot classification\n\n```ruby\nclassifier = Informers.pipeline(\"zero-shot-classification\")\nclassifier.(\"text\", [\"label1\", \"label2\", \"label3\"])\n```\n\nText generation\n\n```ruby\ngenerator = Informers.pipeline(\"text-generation\")\ngenerator.(\"I enjoy walking with my cute dog,\")\n```\n\nText-to-text generation\n\n```ruby\ntext2text = Informers.pipeline(\"text2text-generation\")\ntext2text.(\"translate from English to French: I'm very happy\")\n```\n\nTranslation\n\n```ruby\ntranslator = Informers.pipeline(\"translation\", \"Xenova/nllb-200-distilled-600M\")\ntranslator.(\"जीवन एक चॉकलेट बॉक्स की तरह है।\", src_lang: \"hin_Deva\", tgt_lang: \"fra_Latn\")\n```\n\nSummarization\n\n```ruby\nsummarizer = Informers.pipeline(\"summarization\")\nsummarizer.(\"Many paragraphs of text\")\n```\n\nFill mask\n\n```ruby\nunmasker = Informers.pipeline(\"fill-mask\")\nunmasker.(\"Paris is the [MASK] of France.\")\n```\n\nFeature extraction\n\n```ruby\nextractor = 
Informers.pipeline(\"feature-extraction\")\nextractor.(\"We are very happy to show you the 🤗 Transformers library.\")\n```\n\n### Vision\n\nNote: [ruby-vips](https://github.com/libvips/ruby-vips) is required to load images\n\nImage classification\n\n```ruby\nclassifier = Informers.pipeline(\"image-classification\")\nclassifier.(\"image.jpg\")\n```\n\nZero-shot image classification\n\n```ruby\nclassifier = Informers.pipeline(\"zero-shot-image-classification\")\nclassifier.(\"image.jpg\", [\"label1\", \"label2\", \"label3\"])\n```\n\nImage segmentation\n\n```ruby\nsegmenter = Informers.pipeline(\"image-segmentation\")\nsegmenter.(\"image.jpg\")\n```\n\nObject detection\n\n```ruby\ndetector = Informers.pipeline(\"object-detection\")\ndetector.(\"image.jpg\")\n```\n\nZero-shot object detection\n\n```ruby\ndetector = Informers.pipeline(\"zero-shot-object-detection\")\ndetector.(\"image.jpg\", [\"label1\", \"label2\", \"label3\"])\n```\n\nDepth estimation\n\n```ruby\nestimator = Informers.pipeline(\"depth-estimation\")\nestimator.(\"image.jpg\")\n```\n\nImage-to-image\n\n```ruby\nupscaler = Informers.pipeline(\"image-to-image\")\nupscaler.(\"image.jpg\")\n```\n\nImage feature extraction\n\n```ruby\nextractor = Informers.pipeline(\"image-feature-extraction\")\nextractor.(\"image.jpg\")\n```\n\n### Audio\n\nNote: [ffmpeg](https://www.ffmpeg.org/) is required to load audio files\n\nAudio classification\n\n```ruby\nclassifier = Informers.pipeline(\"audio-classification\")\nclassifier.(\"audio.wav\")\n```\n\n### Multimodal\n\nImage captioning\n\n```ruby\ncaptioner = Informers.pipeline(\"image-to-text\")\ncaptioner.(\"image.jpg\")\n```\n\nDocument question answering\n\n```ruby\nqa = Informers.pipeline(\"document-question-answering\")\nqa.(\"image.jpg\", \"What is the invoice number?\")\n```\n\n## Reference\n\nSpecify a variant of the model if available (`fp32`, `fp16`, `int8`, `uint8`, `q8`, `q4`, `q4f16`, or `bnb4`)\n\n```ruby\nInformers.pipeline(\"embedding\", 
\"Xenova/all-MiniLM-L6-v2\", dtype: \"fp16\")\n```\n\nSpecify a device (`cpu`, `cuda`, or `coreml`)\n\n```ruby\nInformers.pipeline(\"embedding\", device: \"cuda\")\n```\n\nNote: Follow [these instructions](https://github.com/ankane/onnxruntime-ruby?tab=readme-ov-file#gpu-support) for `cuda`\n\nSpecify ONNX Runtime [session options](https://github.com/ankane/onnxruntime-ruby?tab=readme-ov-file#session-options)\n\n```ruby\nInformers.pipeline(\"embedding\", session_options: {log_severity_level: 2})\n```\n\n## Credits\n\nThis library was ported from [Transformers.js](https://github.com/huggingface/transformers.js) and is available under the same license.\n\n## History\n\nView the [changelog](https://github.com/ankane/informers/blob/master/CHANGELOG.md)\n\n## Contributing\n\nEveryone is encouraged to help improve this project. Here are a few ways you can help:\n\n- [Report bugs](https://github.com/ankane/informers/issues)\n- Fix bugs and [submit pull requests](https://github.com/ankane/informers/pulls)\n- Write, clarify, or fix documentation\n- Suggest or add new features\n\nTo get started with development:\n\n```sh\ngit clone https://github.com/ankane/informers.git\ncd informers\nbundle install\nbundle exec rake download:files\nbundle exec rake test\n```\n"
  },
  {
    "path": "Rakefile",
    "content": "require \"bundler/gem_tasks\"\nrequire \"rake/testtask\"\n\nRake::TestTask.new do |t|\n  t.pattern = FileList[\"test/**/*_test.rb\"].exclude(\"test/model_test.rb\")\nend\n\ntask default: :test\n\ndef download_file(url)\n  require \"open-uri\"\n\n  file = File.basename(url)\n  puts \"Downloading #{file}...\"\n  dest = \"test/support/#{file}\"\n  File.binwrite(dest, URI.parse(url).read)\n  puts \"Saved #{dest}\"\nend\n\nnamespace :download do\n  task :files do\n    Dir.mkdir(\"test/support\") unless Dir.exist?(\"test/support\")\n\n    download_file(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\")\n    download_file(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_1.png\")\n  end\nend\n"
  },
  {
    "path": "informers.gemspec",
    "content": "require_relative \"lib/informers/version\"\n\nGem::Specification.new do |spec|\n  spec.name          = \"informers\"\n  spec.version       = Informers::VERSION\n  spec.summary       = \"Fast transformer inference for Ruby\"\n  spec.homepage      = \"https://github.com/ankane/informers\"\n  spec.license       = \"Apache-2.0\"\n\n  spec.author        = \"Andrew Kane\"\n  spec.email         = \"andrew@ankane.org\"\n\n  spec.files         = Dir[\"*.{md,txt}\", \"{lib}/**/*\"]\n  spec.require_path  = \"lib\"\n\n  spec.required_ruby_version = \">= 3.3\"\n\n  spec.add_dependency \"onnxruntime\", \">= 0.9\"\n  spec.add_dependency \"tokenizers\", \">= 0.5.3\"\nend\n"
  },
  {
    "path": "lib/informers/backends/onnx.rb",
    "content": "module Informers\n  module Backends\n    module Onnx\n      def self.device_to_execution_providers(device)\n        case device&.to_s\n        when \"cpu\", nil\n          []\n        when \"cuda\"\n          [\"CUDAExecutionProvider\"]\n        when \"coreml\"\n          [\"CoreMLExecutionProvider\"]\n        else\n          supported_devices = [\"cpu\", \"cuda\", \"coreml\"]\n          raise ArgumentError, \"Unsupported device: #{device}. Should be one of: #{supported_devices.join(\", \")}\"\n        end\n      end\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/configs.rb",
    "content": "module Informers\n  class PretrainedConfig\n    def initialize(config_json)\n      @config_json = config_json.to_h\n    end\n\n    def [](key)\n      @config_json[key.to_s]\n    end\n\n    def []=(key, value)\n      @config_json[key.to_s] = value\n    end\n\n    def to_h\n      @config_json.to_h\n    end\n\n    def self.from_pretrained(\n      pretrained_model_name_or_path,\n      progress_callback: nil,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      **kwargs\n    )\n      data = config || load_config(\n        pretrained_model_name_or_path,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:\n      )\n      new(data)\n    end\n\n    def self.load_config(pretrained_model_name_or_path, **options)\n      info = Utils::Hub.get_model_json(pretrained_model_name_or_path, \"config.json\", true, **options)\n      info\n    end\n  end\n\n  class AutoConfig\n    def self.from_pretrained(...)\n      PretrainedConfig.from_pretrained(...)\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/env.rb",
    "content": "module Informers\n  CACHE_HOME = ENV.fetch(\"XDG_CACHE_HOME\", File.join(ENV.fetch(\"HOME\"), \".cache\"))\n  DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, \"informers\"))\n\n  class << self\n    attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir\n  end\n\n  self.allow_remote_models = ENV[\"INFORMERS_OFFLINE\"].to_s.empty?\n  self.remote_host = \"https://huggingface.co/\"\n  self.remote_path_template = \"{model}/resolve/{revision}/\"\n\n  self.cache_dir = DEFAULT_CACHE_DIR\nend\n"
  },
  {
    "path": "lib/informers/model.rb",
    "content": "module Informers\n  # TODO remove in 2.0\n  class Model\n    def initialize(model_id, quantized: false)\n      @model = Informers.pipeline(\"embedding\", model_id, quantized: quantized)\n      @options = model_id == \"mixedbread-ai/mxbai-embed-large-v1\" ? {pooling: \"cls\", normalize: false} : {}\n    end\n\n    def embed(texts)\n      @model.(texts, **@options)\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/models.rb",
    "content": "module Informers\n  MODEL_TYPES = {\n    EncoderOnly: 0,\n    EncoderDecoder: 1,\n    Seq2Seq: 2,\n    Vision2Seq: 3,\n    DecoderOnly: 4,\n    MaskGeneration: 5\n  }\n\n  # NOTE: These will be populated fully later\n  MODEL_TYPE_MAPPING = {}\n  MODEL_NAME_TO_CLASS_MAPPING = {}\n  MODEL_CLASS_TO_NAME_MAPPING = {}\n\n  class PretrainedMixin\n    def self.from_pretrained(\n      pretrained_model_name_or_path,\n      quantized: true,\n      progress_callback: nil,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      device: nil,\n      dtype: nil,\n      model_file_name: nil,\n      session_options: {}\n    )\n      options = {\n        quantized:,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:,\n        device:,\n        dtype:,\n        model_file_name:,\n        session_options:\n      }\n      config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)\n      if options[:config].nil?\n        # If no config was passed, reuse this config for future processing\n        options[:config] = config\n      end\n\n      if !const_defined?(:MODEL_CLASS_MAPPINGS)\n        raise Error, \"`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}\"\n      end\n\n      const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|\n        model_info = model_class_mapping[config[:model_type]]\n        if !model_info\n          next # Item not found in this mapping\n        end\n        return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)\n      end\n\n      if const_defined?(:BASE_IF_FAIL)\n        warn \"Unknown model class #{config[:model_type].inspect}, attempting to construct from base class.\"\n        PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)\n      else\n        raise Error, \"Unsupported model type: #{config[:model_type]}\"\n      end\n    
end\n  end\n\n  class PreTrainedModel\n    MAIN_INPUT_NAME = :input_ids\n\n    attr_reader :config\n\n    def initialize(config, session)\n      super()\n\n      @config = config\n      @session = session\n\n      @output_names = nil\n\n      model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]\n      model_type = MODEL_TYPE_MAPPING[model_name]\n\n      case model_type\n      when MODEL_TYPES[:DecoderOnly]\n        @can_generate = true\n\n        @run_beam = method(:decoder_run_beam)\n        @get_start_beams = method(:decoder_start_beams)\n        @update_beam = method(:decoder_update_beam)\n        @forward = method(:decoder_forward)\n\n      when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]\n        @can_generate = true\n\n        @run_beam = method(:seq2seq_run_beam)\n        @get_start_beams = method(:seq2seq_start_beams)\n        @update_beam = method(:seq2seq_update_beam)\n        @forward = method(:seq2seq_forward)\n\n      when MODEL_TYPES[:EncoderDecoder]\n        @forward = method(:encoder_forward)\n\n      else\n        @forward = method(:encoder_forward)\n      end\n    end\n\n    def self.from_pretrained(\n      pretrained_model_name_or_path,\n      quantized: true,\n      progress_callback: nil,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      device: nil,\n      dtype: nil,\n      model_file_name: nil,\n      session_options: {}\n    )\n      options = {\n        quantized:,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:,\n        device:,\n        dtype:,\n        model_file_name:,\n        session_options:\n      }\n\n      model_name = MODEL_CLASS_TO_NAME_MAPPING[self]\n      model_type = MODEL_TYPE_MAPPING[model_name]\n\n      config ||= AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)\n\n      if model_type == MODEL_TYPES[:DecoderOnly]\n        info = [\n          
construct_session(pretrained_model_name_or_path, options[:model_file_name] || \"decoder_model_merged\", **options),\n          Utils::Hub.get_model_json(pretrained_model_name_or_path, \"generation_config.json\", false, **options)\n        ]\n\n      elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]\n        info = [\n          construct_session(pretrained_model_name_or_path, \"encoder_model\", **options),\n          construct_session(pretrained_model_name_or_path, \"decoder_model_merged\", **options),\n          Utils::Hub.get_model_json(pretrained_model_name_or_path, \"generation_config.json\", false, **options)\n        ]\n\n      elsif model_type == MODEL_TYPES[:MaskGeneration]\n        info = [\n          construct_session(pretrained_model_name_or_path, \"vision_encoder\", **options),\n          construct_session(pretrained_model_name_or_path, \"prompt_encoder_mask_decoder\", **options)\n        ]\n\n      elsif model_type == MODEL_TYPES[:EncoderDecoder]\n        info = [\n          construct_session(pretrained_model_name_or_path, \"encoder_model\", **options),\n          construct_session(pretrained_model_name_or_path, \"decoder_model_merged\", **options)\n        ]\n\n      else\n        if model_type != MODEL_TYPES[:EncoderOnly]\n          warn \"Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. 
Please report this.\"\n        end\n        info = [\n          construct_session(pretrained_model_name_or_path, options[:model_file_name] || \"model\", **options)\n        ]\n      end\n\n      new(config, *info)\n    end\n\n    def self.construct_session(pretrained_model_name_or_path, file_name, **options)\n      prefix = \"onnx/\"\n      if file_name.start_with?(\"../\")\n        prefix = \"\"\n        file_name = file_name[3..]\n      elsif file_name.start_with?(\"/\")\n        prefix = \"\"\n        file_name = file_name[1..]\n      end\n      dtype = options[:dtype] || (options[:quantized] ? \"q8\" : \"fp32\")\n      suffix = Utils::DEFAULT_DTYPE_SUFFIX_MAPPING[dtype.to_sym]\n      if !suffix\n        raise ArgumentError, \"Invalid dtype: #{dtype}. Should be one of: #{Utils::DEFAULT_DTYPE_SUFFIX_MAPPING.keys.join(\", \")}\"\n      end\n      model_file_name = \"#{prefix}#{file_name}#{suffix}.onnx\"\n      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)\n\n      session_options = {\n        providers: Backends::Onnx.device_to_execution_providers(options[:device]),\n        log_severity_level: 4\n      }.merge(options[:session_options] || {})\n\n      begin\n        OnnxRuntime::InferenceSession.new(path, **session_options)\n      rescue OnnxRuntime::Error => e\n        raise e unless e.message.include?(\"No such file or directory\") && e.message.include?(\".onnx_data\")\n\n        Utils::Hub.get_model_file(pretrained_model_name_or_path, \"#{model_file_name}_data\", true, **options)\n        OnnxRuntime::InferenceSession.new(path, **session_options)\n      end\n    end\n\n    def call(model_inputs, **kwargs)\n      @forward.(model_inputs, **kwargs)\n    end\n\n    def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)\n      if !@can_generate\n        model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]\n        error_message = \"The current model class (#{model_name}) 
is not compatible with `.generate()`, as it doesn't have a language model head.\"\n        raise Error, error_message\n      end\n\n      if !inputs.is_a?(Array)\n        raise ArgumentError, \"`inputs` must be an Array, but is #{inputs.class.name}\"\n      end\n\n      if @config[:is_encoder_decoder]\n        # Generating from the encoder outputs\n        input_ids_seq_length = 0\n      else\n        input_ids_seq_length = inputs.length\n\n        # decoder-only\n        if input_ids_seq_length == 0\n          raise Error, \"Must supply a non-empty array of input token ids.\"\n        end\n      end\n\n      # Update generation config with defaults\n      generation_config = get_generation_config(generation_config)\n\n      logits_processor ||= Utils::LogitsProcessorList.new\n\n      # Update logits processor\n      logits_processor = get_logits_processor(\n        generation_config,\n        input_ids_seq_length,\n        logits_processor\n      )\n\n      eos_token_ids = generation_config[:eos_token_id]\n      if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)\n        eos_token_ids = [eos_token_ids]\n      end\n\n      num_output_tokens = 1\n      max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)\n\n      # Only use max length if max_new_tokens is not provided\n      use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?\n      sampler = Utils::Sampler.get_sampler(generation_config)\n\n      beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)\n\n      while beams.any? 
{ |x| !x[:done] } && num_output_tokens < max_output_tokens\n        newest_beams = []\n        beams.each do |beam|\n          if beam[:done]\n            # Add this beam back into the pool\n            newest_beams << beam\n            next\n          end\n          if use_max_length && beam[:output_token_ids].length >= generation_config[\"max_length\"]\n            # Set this beam to done and add it back into the pool\n            beam[:done] = true\n            newest_beams << beam\n            next\n          end\n\n          output = run_beam(beam)\n\n          # add attentions/scores to beam only if user requested\n          if generation_config[\"output_attentions\"]\n            add_attentions_to_beam(beam, output)\n          end\n\n          # Logits are of the form [batch_size, out_seq_length, vocab_size]\n          # In most cases, this will be [batch_size, 1, vocab_size]\n          # So, we select the last token's logits:\n          # (equivalent to `logits = outputs.logits[:, -1, :]`)\n          logits = output[\"logits\"].map { |v| v[-1] }\n\n          # Apply logits processor\n          logits_processor.(beam[:output_token_ids], logits)\n\n          sampled_tokens = sampler.(logits)\n          sampled_tokens.each do |new_token_id, log_prob|\n            # use previous beam as a starting point\n            new_beam = beam.dup\n\n            # update new beam\n            update_beam(new_beam, new_token_id)\n\n            new_beam[:score] += log_prob\n\n            if eos_token_ids && eos_token_ids.include?(new_token_id)\n              new_beam[:done] = true\n            end\n\n            newest_beams << new_beam\n          end\n        end\n        num_output_tokens += 1\n\n        # Next, we get the best beams, per ID\n        newest_beams =\n          group_beams(newest_beams).map do |group|\n            group.sort_by { |v| -v[:score] }[0...generation_config[\"num_beams\"]]\n          end\n\n        # Flatten beams\n        beams = 
newest_beams.flatten(1)\n\n        # Run callback\n        if generation_config[\"callback_function\"]\n          generation_config[\"callback_function\"].(beams)\n        end\n      end\n\n      # TODO: Ensure that we can return non-batched outputs\n\n      grouped_beams = group_beams(beams)\n\n      get_flattened = lambda do |key|\n        grouped_beams.flat_map do |batch|\n          if generation_config[\"num_return_sequences\"] > 1\n            raise Todo\n          else\n            [batch[0][key]]\n          end\n        end\n      end\n\n      sequences = get_flattened.(:output_token_ids) # [1, seqLength]\n\n      if generation_config[\"return_dict_in_generate\"]\n        raise Todo\n      else\n        sequences\n      end\n    end\n\n    private\n\n    def get_logits_processor(\n      generation_config,\n      input_ids_seq_length,\n      logits_processor = nil\n    )\n      processors = Utils::LogitsProcessorList.new\n\n      if !generation_config[\"repetition_penalty\"].nil? && generation_config[\"repetition_penalty\"] != 1.0\n        processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config[\"repetition_penalty\"]))\n      end\n\n      if !generation_config[\"no_repeat_ngram_size\"].nil? && generation_config[\"no_repeat_ngram_size\"] > 0\n        processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config[\"no_repeat_ngram_size\"]))\n      end\n\n      if !generation_config[\"bad_words_ids\"].nil?\n        processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config[\"bad_words_ids\"], generation_config[\"eos_token_id\"]))\n      end\n\n      if !generation_config[\"min_length\"].nil? && !generation_config[\"eos_token_id\"].nil? && generation_config[\"min_length\"] > 0\n        processors.push(Utils::MinLengthLogitsProcessor.new(generation_config[\"min_length\"], generation_config[\"eos_token_id\"]))\n      end\n\n      if !generation_config[\"min_new_tokens\"].nil? && !generation_config[\"eos_token_id\"].nil? 
&& generation_config[\"min_new_tokens\"] > 0\n        processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(\n          input_ids_seq_length,\n          generation_config[\"min_new_tokens\"],\n          generation_config[\"eos_token_id\"]\n        ))\n      end\n\n      if !generation_config[\"forced_bos_token_id\"].nil?\n        processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config[\"forced_bos_token_id\"]))\n      end\n\n      if !generation_config[\"forced_eos_token_id\"].nil?\n        processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(\n          generation_config[\"max_length\"],\n          generation_config[\"forced_eos_token_id\"]\n        ))\n      end\n\n      if !generation_config[\"begin_suppress_tokens\"].nil?\n        raise Todo\n      end\n\n      if !generation_config[\"forced_decoder_ids\"].nil?\n        processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config[\"forced_decoder_ids\"]))\n      end\n\n      if !logits_processor.nil?\n        processors.concat(logits_processor)\n      end\n\n      processors\n    end\n\n    def get_generation_config(generation_config)\n      # Create empty generation config (contains defaults)\n      # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them\n      gen_config = Utils::GenerationConfig.new(@config.to_h)\n\n      # Apply model's generation config, if it exists\n      if @generation_config\n        gen_config.merge!(@generation_config)\n      end\n\n      # Finally, use any generation config specified by the user\n      # when calling `generate`\n      if !generation_config.nil?\n        gen_config.merge!(generation_config)\n      end\n\n      gen_config\n    end\n\n    def seq2seq_forward(model_inputs)\n      encoder_outputs = model_inputs[:encoder_outputs]\n      past_key_values = model_inputs[:past_key_values]\n\n      if !encoder_outputs\n        # Encoder outputs are not given, so we must compute 
them.\n        encoder_outputs = encoder_forward(model_inputs)[0]\n      end\n      decoder_feeds = {\n        input_ids: model_inputs[:decoder_input_ids],\n        encoder_hidden_states: encoder_outputs\n      }\n      use_cache_branch = !!past_key_values\n\n      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?(\"use_cache_branch\")\n        decoder_feeds[:use_cache_branch] = [use_cache_branch]\n      end\n\n      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?(\"encoder_attention_mask\")\n        decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]\n      end\n\n      prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)\n      add_past_key_values(decoder_feeds, past_key_values)\n\n      decoder_results = session_run(@decoder_merged_session, decoder_feeds)\n      decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h\n      logits = decoder_results[\"logits\"]\n      past_key_values = get_past_key_values(decoder_results, past_key_values)\n\n      # Get cross attention and/or decoder attentions if they are present\n      attns = get_attentions(decoder_results)\n\n      Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns[\"decoder_attentions\"], attns[\"cross_attentions\"])\n    end\n\n    def prepare_position_ids(session, feeds, use_cache_branch)\n      if !session.inputs.map { |v| v[:name] }.include?(\"position_ids\")\n        return\n      end\n\n      raise Todo\n    end\n\n    def get_past_key_values(decoder_results, past_key_values)\n      pkvs = {}\n\n      decoder_results.each_key do |name|\n        if name.start_with?(\"present\")\n          new_name = name.sub(\"present\", \"past_key_values\")\n\n          if past_key_values && name.include?(\"encoder\")\n            # Optimization introduced by optimum to reuse past key values. 
So, we just replace the constant\n            # outputs with the previous past key values.\n            # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704\n            pkvs[new_name] = past_key_values[new_name]\n          else\n            pkvs[new_name] = decoder_results[name]\n          end\n        end\n      end\n      pkvs\n    end\n\n    def get_attentions(decoder_results)\n      attns = {}\n\n      [\"cross_attentions\", \"decoder_attentions\"].each do |attn_name|\n        result = []\n        decoder_results.each_key do |name|\n          if name.start_with?(attn_name)\n            index = name.split(\".\").pop\n            result[index] = decoder_results[name]\n          end\n        end\n        attns[attn_name] = result\n      end\n      attns\n    end\n\n    def add_past_key_values(decoder_feeds, past_key_values)\n      if past_key_values\n        decoder_feeds.merge!(past_key_values)\n      else\n        # TODO support batches (i.e., batch_size > 1)\n        batch_size = 1\n\n        if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? 
@add_encoder_pkv : true)\n          _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]\n          _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]\n          @num_decoder_layers.times do |i|\n            # decoder_feeds[\"past_key_values.#{i}.encoder.key\"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)\n            # decoder_feeds[\"past_key_values.#{i}.encoder.value\"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)\n            # decoder_feeds[\"past_key_values.#{i}.decoder.key\"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)\n            # decoder_feeds[\"past_key_values.#{i}.decoder.value\"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)\n          end\n        elsif @config[:model_type] == \"falcon\"\n          raise Todo\n        elsif @config[:multi_query]\n          raise Todo\n        elsif @config[:model_type] == \"bloom\"\n          raise Todo\n        else\n          _dims = [batch_size, @num_heads, 0, @dim_kv]\n          @num_layers.times do |i|\n            # decoder_feeds[\"past_key_values.#{i}.key\"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)\n            # decoder_feeds[\"past_key_values.#{i}.value\"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)\n          end\n        end\n      end\n    end\n\n    def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)\n      beams = []\n      beam_id = 0\n\n      requires_attention_mask = !@requires_attention_mask.nil? ? 
@requires_attention_mask : true\n\n      # decoder_input_ids == output_token_ids\n      decoder_input_ids =\n        generation_config[\"decoder_input_ids\"] ||\n        generation_config[\"decoder_start_token_id\"] ||\n        generation_config[\"bos_token_id\"] ||\n        generation_config[\"eos_token_id\"]\n\n      if !decoder_input_ids.is_a?(Array)\n        decoder_input_ids = [decoder_input_ids]\n      end\n\n      input_token_ids.each do |tokens|\n        # TODO: Improve\n        # Currently, just add back batch dimension.\n        # In future, allow for true parallel execution\n        tokens = [tokens]\n\n        # Create beam\n        start = {\n          inputs: tokens,\n          encoder_outputs: nil,\n          prev_model_outputs: nil,\n\n          output_token_ids: decoder_input_ids,\n          done: false,\n          score: 0,\n          id: beam_id # assign unique id to beams\n        }\n        beam_id += 1\n\n        if requires_attention_mask\n          start[:attention_mask] = prepare_attention_mask(tokens)\n        end\n\n        beams << start\n      end\n\n      beams\n    end\n\n    def prepare_attention_mask(tokens)\n      # Prepare attention mask\n      pad_token_id = @config[\"pad_token_id\"]\n      eos_token_id = @config[\"eos_token_id\"]\n      if eos_token_id.is_a?(Integer)\n        eos_token_id = [eos_token_id]\n      end\n\n      is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?\n      is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? 
|| !eos_token_id.include?(pad_token_id)\n\n      if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id\n        raise Todo\n      else\n        Utils.ones_like(tokens)\n      end\n    end\n\n    def seq2seq_run_beam(beam)\n      input_name = self.class.const_get(:MAIN_INPUT_NAME)\n\n      decoder_input_ids = beam[:output_token_ids]\n      if beam[:prev_model_outputs]\n        # After the first step, `prev_model_outputs` won't be null.\n        # So, we cut decoder_input_ids if past is used\n        decoder_input_ids = [decoder_input_ids[-1]]\n      end\n\n      # 1. Prepare\n      model_inputs = {\n        input_name => beam[:inputs],\n        decoder_input_ids: [decoder_input_ids],\n        encoder_outputs: beam[:encoder_outputs],\n        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]\n      }\n      if beam[:attention_mask]\n        model_inputs[:attention_mask] = beam[:attention_mask]\n      end\n\n      # 2. Run\n      output = @forward.(model_inputs)\n\n      # 3. Update\n      beam[:prev_model_outputs] = output\n      beam[:encoder_outputs] = output[:encoder_outputs]\n\n      output\n    end\n\n    def seq2seq_update_beam(beam, new_token_id)\n      beam[:output_token_ids] += [new_token_id]\n    end\n\n    def group_beams(beams)\n      # Group beams by their ids\n      groups = {}\n      beams.each do |obj|\n        if !groups[obj[:id]]\n          groups[obj[:id]] = [obj]\n        else\n          groups[obj[:id]] << obj\n        end\n      end\n      groups.values\n    end\n\n    def encoder_forward(model_inputs, output_names: nil)\n      encoder_feeds = {}\n      @session.inputs.each do |input|\n        key = input[:name].to_sym\n        encoder_feeds[key] = model_inputs[key]\n      end\n      if @session.inputs.any? 
{ |v| v[:name] == \"token_type_ids\" } && !encoder_feeds[:token_type_ids]\n        raise Todo\n      end\n      session_run(@session, encoder_feeds, output_names:)\n    end\n\n    def decoder_forward(model_inputs)\n      input_ids, past_key_values, attention_mask =\n        model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)\n      decoder_feeds = {\n        input_ids: input_ids,\n        attention_mask: attention_mask || prepare_attention_mask(input_ids)\n      }\n      use_cache_branch = !!past_key_values\n\n      if @session.inputs.map { |v| v[:name] }.include?(\"use_cache_branch\")\n        decoder_feeds[:use_cache_branch] = [use_cache_branch]\n      end\n\n      prepare_position_ids(@session, decoder_feeds, use_cache_branch)\n\n      add_past_key_values(decoder_feeds, past_key_values)\n\n      decoder_results = session_run(@session, decoder_feeds)\n      decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h\n\n      logits = decoder_results[\"logits\"]\n\n      past_key_values = get_past_key_values(decoder_results, past_key_values)\n      {\"logits\" => logits, past_key_values: past_key_values}\n    end\n\n    def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)\n      beams = []\n\n      beam_id = 0\n      input_token_ids.each do |tokens|\n        output_token_ids = tokens.dup\n\n        # TODO: Improve\n        # Currently, just add back batch dimension.\n        # In future, allow for true parallel execution\n        tokens = [tokens]\n\n        if inputs_attention_mask\n          attn_mask = inputs_attention_mask[beam_id]\n          attn_mask = [attn_mask]\n        else\n          attn_mask = prepare_attention_mask(tokens)\n        end\n\n        start = {\n          input: tokens,\n          model_input_ids: tokens,\n          attention_mask: attn_mask,\n          prev_model_outputs: nil,\n\n          output_token_ids: output_token_ids,\n          
num_output_tokens: num_output_tokens,\n\n          done: false,\n          score: 0,\n          id: beam_id # assign unique id to beams\n        }\n        beam_id += 1\n\n        beams << start\n      end\n      beams\n    end\n\n    def decoder_run_beam(beam)\n      attn_mask_data = Array.new(beam[:output_token_ids].length, 1)\n\n      # 1. Prepare\n      model_inputs = {\n        input_ids: beam[:model_input_ids],\n        attention_mask: [attn_mask_data],\n        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]\n      }\n\n      # 2. Run\n      output = @forward.(model_inputs)\n\n      # 3. Update\n      beam[:prev_model_outputs] = output\n\n      output\n    end\n\n    def decoder_update_beam(beam, new_token_id)\n      beam[:output_token_ids] += [new_token_id]\n      beam[:model_input_ids] = [[new_token_id]]\n    end\n\n    def session_run(session, inputs, output_names: nil)\n      checked_inputs = validate_inputs(session, inputs)\n      begin\n        output = session.run(output_names || @output_names, checked_inputs)\n        output = replace_tensors(output)\n        output\n      rescue => e\n        raise e\n      end\n    end\n\n    # TODO\n    def replace_tensors(obj)\n      obj\n    end\n\n    # TODO\n    def validate_inputs(session, inputs)\n      inputs\n    end\n\n    def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)\n      @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)\n    end\n\n    def run_beam(beam)\n      @run_beam.(beam)\n    end\n\n    def update_beam(beam, new_token_id)\n      @update_beam.(beam, new_token_id)\n    end\n  end\n\n  class BertPreTrainedModel < PreTrainedModel\n  end\n\n  class BertModel < BertPreTrainedModel\n  end\n\n  class BertForMaskedLM < BertPreTrainedModel\n    def call(model_inputs)\n      MaskedLMOutput.new(*super(model_inputs))\n    end\n  end\n\n  class 
BertForSequenceClassification < BertPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class BertForTokenClassification < BertPreTrainedModel\n    def call(model_inputs)\n      TokenClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class ModernBertPreTrainedModel < PreTrainedModel\n  end\n\n  class ModernBertModel < ModernBertPreTrainedModel\n  end\n\n  class ModernBertForMaskedLM < ModernBertPreTrainedModel\n    def call(model_inputs)\n      MaskedLMOutput.new(*super(model_inputs))\n    end\n  end\n\n  class ModernBertForSequenceClassification < ModernBertPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class ModernBertForTokenClassification < ModernBertPreTrainedModel\n    def call(model_inputs)\n      TokenClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class NomicBertPreTrainedModel < PreTrainedModel\n  end\n\n  class NomicBertModel < NomicBertPreTrainedModel\n  end\n\n  class ConvBertPreTrainedModel < PreTrainedModel\n  end\n\n  class ConvBertModel < ConvBertPreTrainedModel\n  end\n\n  class ElectraPreTrainedModel < PreTrainedModel\n  end\n\n  # TODO add ElectraForPreTraining\n  class ElectraModel < ElectraPreTrainedModel\n  end\n\n  class DebertaV2PreTrainedModel < PreTrainedModel\n  end\n\n  class DebertaV2Model < DebertaV2PreTrainedModel\n  end\n\n  class DistilBertPreTrainedModel < PreTrainedModel\n  end\n\n  class DistilBertModel < DistilBertPreTrainedModel\n  end\n\n  class DistilBertForSequenceClassification < DistilBertPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class DistilBertForQuestionAnswering < DistilBertPreTrainedModel\n    def call(model_inputs)\n      QuestionAnsweringModelOutput.new(*super(model_inputs))\n    end\n  end\n\n  class MPNetPreTrainedModel < PreTrainedModel\n  end\n\n  class 
MPNetModel < MPNetPreTrainedModel\n  end\n\n  class T5PreTrainedModel < PreTrainedModel\n  end\n\n  class T5Model < T5PreTrainedModel\n  end\n\n  class T5ForConditionalGeneration < T5PreTrainedModel\n    def initialize(config, session, decoder_merged_session, generation_config)\n      super(config, session)\n      @decoder_merged_session = decoder_merged_session\n      @generation_config = generation_config\n\n      @num_decoder_layers = @config[:num_decoder_layers]\n      @num_decoder_heads = @config[:num_heads]\n      @decoder_dim_kv = @config[:d_kv]\n\n      @num_encoder_layers = @config[:num_layers]\n      @num_encoder_heads = @config[:num_heads]\n      @encoder_dim_kv = @config[:d_kv]\n    end\n  end\n\n  class BartPretrainedModel < PreTrainedModel\n  end\n\n  class BartModel < BartPretrainedModel\n  end\n\n  class BartForConditionalGeneration < BartPretrainedModel\n    def initialize(config, session, decoder_merged_session, generation_config)\n      super(config, session)\n      @decoder_merged_session = decoder_merged_session\n      @generation_config = generation_config\n\n      @num_decoder_layers = @config[\"decoder_layers\"]\n      @num_decoder_heads = @config[\"decoder_attention_heads\"]\n      @decoder_dim_kv = @config[\"d_model\"] / @num_decoder_heads.to_f\n\n      @num_encoder_layers = @config[\"encoder_layers\"]\n      @num_encoder_heads = @config[\"encoder_attention_heads\"]\n      @encoder_dim_kv = @config[\"d_model\"] / @num_encoder_heads\n    end\n  end\n\n  class BartForSequenceClassification < BartPretrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class MBartPreTrainedModel < PreTrainedModel\n  end\n\n  class MBartModel < MBartPreTrainedModel\n  end\n\n  class MBartForCausalLM < MBartPreTrainedModel\n    attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,\n      :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv\n\n    def initialize(config, 
decoder_merged_session, generation_config)\n      super(config, decoder_merged_session)\n      @generation_config = generation_config\n\n      @num_decoder_layers = @config[\"decoder_layers\"]\n      @num_decoder_heads = @config[\"decoder_attention_heads\"]\n      @decoder_dim_kv = @config[\"d_model\"] / @num_decoder_heads.to_f\n\n      @num_encoder_layers = @config[\"encoder_layers\"]\n      @num_encoder_heads = @config[\"encoder_attention_heads\"]\n      @encoder_dim_kv = @config[\"d_model\"] / @num_encoder_heads.to_f\n    end\n  end\n\n  class M2M100PreTrainedModel < PreTrainedModel\n  end\n\n  class M2M100Model < M2M100PreTrainedModel\n  end\n\n  class M2M100ForConditionalGeneration < M2M100PreTrainedModel\n    def initialize(config, session, decoder_merged_session, generation_config)\n      super(config, session)\n      @decoder_merged_session = decoder_merged_session\n      @generation_config = generation_config\n\n      @num_decoder_layers = @config[\"decoder_layers\"]\n      @num_decoder_heads = @config[\"decoder_attention_heads\"]\n      @decoder_dim_kv = @config[\"d_model\"] / @num_decoder_heads.to_f\n\n      @num_encoder_layers = @config[\"encoder_layers\"]\n      @num_encoder_heads = @config[\"encoder_attention_heads\"]\n      @encoder_dim_kv = @config[\"d_model\"] / @num_encoder_heads.to_f\n    end\n  end\n\n  class Wav2Vec2PreTrainedModel < PreTrainedModel\n  end\n\n  class Wav2Vec2Model < Wav2Vec2PreTrainedModel\n  end\n\n  class Wav2Vec2ForSequenceClassification < Wav2Vec2PreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class RobertaPreTrainedModel < PreTrainedModel\n  end\n\n  class RobertaModel < RobertaPreTrainedModel\n  end\n\n  class RobertaForMaskedLM < RobertaPreTrainedModel\n    def call(model_inputs)\n      MaskedLMOutput.new(*super(model_inputs))\n    end\n  end\n\n  class RobertaForTokenClassification <  RobertaPreTrainedModel\n    def call(model_inputs)\n      
TokenClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class RobertaForSequenceClassification < RobertaPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class XLMRobertaPreTrainedModel < PreTrainedModel\n  end\n\n  class XLMRobertaModel < XLMRobertaPreTrainedModel\n  end\n\n  class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class ViTPreTrainedModel < PreTrainedModel\n  end\n\n  class ViTModel < ViTPreTrainedModel\n  end\n\n  class ViTForImageClassification < ViTPreTrainedModel\n    def call(model_inputs)\n      SequenceClassifierOutput.new(*super(model_inputs))\n    end\n  end\n\n  class CLIPPreTrainedModel < PreTrainedModel\n  end\n\n  class CLIPModel < CLIPPreTrainedModel\n  end\n\n  class GPT2PreTrainedModel < PreTrainedModel\n    attr_reader :num_heads, :num_layers, :dim_kv\n\n    def initialize(config, session, generation_config)\n      super(config, session)\n      @generation_config = generation_config\n\n      # config doesn't contain pad_token_id, so we assume it is the eos_token_id\n      @config[\"pad_token_id\"] = @config[\"eos_token_id\"]\n\n      @num_heads = @config[\"n_head\"]\n      @num_layers = @config[\"n_layer\"]\n      @dim_kv = @config[\"n_embd\"] / @num_heads.to_f\n    end\n  end\n\n  class GPT2Model < GPT2PreTrainedModel\n  end\n\n  class GPT2LMHeadModel < GPT2PreTrainedModel\n  end\n\n  class OwlViTPreTrainedModel < PreTrainedModel\n  end\n\n  class OwlViTModel < OwlViTPreTrainedModel\n  end\n\n  class OwlViTForObjectDetection < OwlViTPreTrainedModel\n  end\n\n  class DetrPreTrainedModel < PreTrainedModel\n  end\n\n  class DetrModel < DetrPreTrainedModel\n  end\n\n  class DetrForObjectDetection < DetrPreTrainedModel\n    def call(model_inputs)\n      DetrObjectDetectionOutput.new(*super(model_inputs))\n    end\n  end\n\n  class 
DetrForSegmentation < DetrPreTrainedModel\n    def call(model_inputs)\n      DetrSegmentationOutput.new(*super(model_inputs))\n    end\n  end\n\n  class Swin2SRPreTrainedModel < PreTrainedModel\n  end\n\n  class Swin2SRModel < Swin2SRPreTrainedModel\n  end\n\n  class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel\n  end\n\n  class DPTPreTrainedModel < PreTrainedModel\n  end\n\n  class DPTModel < DPTPreTrainedModel\n  end\n\n  class DPTForDepthEstimation < DPTPreTrainedModel\n  end\n\n  class VisionEncoderDecoderModel < PreTrainedModel\n    MAIN_INPUT_NAME = :pixel_values\n\n    def initialize(config, session, decoder_merged_session, generation_config)\n      super(config, session)\n      @decoder_merged_session = decoder_merged_session\n      @generation_config = generation_config\n\n      # Extract configs\n      encoder_config = @config[\"encoder\"]\n      decoder_config = @config[\"decoder\"]\n\n      # Validate encoder\n      encoder_model_type = encoder_config[\"model_type\"]\n      encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]\n      if !encoder_model\n        warn \"Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. 
Please report this.\"\n      end\n\n      # Validate decoder\n      decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config[\"model_type\"]]\n      if !decoder_model\n        raise Error, \"Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \\\"#{decoder_config[\"model_type\"]}\\\"\"\n      end\n\n      decoder_model_class = decoder_model[1]\n      decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)\n\n      @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)\n      if @add_encoder_pkv\n        # Decoder is part of an encoder-decoder model\n        @num_decoder_layers = decoder.num_decoder_layers\n        @num_decoder_heads = decoder.num_decoder_heads\n        @decoder_dim_kv = decoder.decoder_dim_kv\n\n        @num_encoder_layers = decoder.num_encoder_layers\n        @num_encoder_heads = decoder.num_encoder_heads\n        @encoder_dim_kv = decoder.encoder_dim_kv\n      else\n        # Decoder is a decoder-only model\n        @num_layers = decoder.num_layers\n        @num_heads = decoder.num_heads\n        @dim_kv = decoder.dim_kv\n      end\n    end\n  end\n\n  class DonutSwinPreTrainedModel < PreTrainedModel\n  end\n\n  class DonutSwinModel < DonutSwinPreTrainedModel\n  end\n\n  class WhisperPreTrainedModel < PreTrainedModel\n  end\n\n  class WhisperModel < WhisperPreTrainedModel\n  end\n\n  class WhisperForConditionalGeneration < WhisperPreTrainedModel\n    REQUIRES_ATTENTION_MASK = false\n    MAIN_INPUT_NAME = :input_features\n\n    def initialize(config, session, decoder_merged_session, generation_config)\n      super(config, session)\n      @decoder_merged_session = decoder_merged_session\n      @generation_config = generation_config\n\n      @num_decoder_layers = @config[\"decoder_layers\"]\n      @num_decoder_heads = @config[\"decoder_attention_heads\"]\n      @decoder_dim_kv = @config[\"d_model\"] / @num_decoder_heads.to_f\n\n      @num_encoder_layers = 
@config[\"encoder_layers\"]\n      @num_encoder_heads = @config[\"encoder_attention_heads\"]\n      @encoder_dim_kv = @config[\"d_model\"] / @num_encoder_heads.to_f\n    end\n\n    def generate(inputs, generation_config = nil, logits_processor = nil)\n      raise Todo\n    end\n  end\n\n  class VitsPreTrainedModel < PreTrainedModel\n  end\n\n  class VitsModel < VitsPreTrainedModel\n    def call(model_inputs)\n      VitsModelOutput.new(*super(model_inputs))\n    end\n  end\n\n  class SpeechT5PreTrainedModel < PreTrainedModel\n  end\n\n  class SpeechT5Model < SpeechT5PreTrainedModel\n  end\n\n  class SpeechT5ForSpeechToText < SpeechT5PreTrainedModel\n  end\n\n  class SpeechT5ForTextToSpeech < SpeechT5PreTrainedModel\n  end\n\n  class ClapPreTrainedModel < PreTrainedModel\n  end\n\n  class ClapModel < ClapPreTrainedModel\n  end\n\n  MODEL_MAPPING_NAMES_ENCODER_ONLY = {\n    \"bert\" => [\"BertModel\", BertModel],\n    \"modernbert\" => [\"ModernBertModel\", ModernBertModel],\n    \"nomic_bert\" => [\"NomicBertModel\", NomicBertModel],\n    \"electra\" => [\"ElectraModel\", ElectraModel],\n    \"convbert\" => [\"ConvBertModel\", ConvBertModel],\n    \"deberta-v2\" => [\"DebertaV2Model\", DebertaV2Model],\n    \"mpnet\" => [\"MPNetModel\", MPNetModel],\n    \"distilbert\" => [\"DistilBertModel\", DistilBertModel],\n    \"roberta\" => [\"RobertaModel\", RobertaModel],\n    \"xlm-roberta\" => [\"XLMRobertaModel\", XLMRobertaModel],\n    \"clap\" => [\"ClapModel\", ClapModel],\n    \"clip\" => [\"CLIPModel\", CLIPModel],\n    \"detr\" => [\"DetrModel\", DetrModel],\n    \"vit\" => [\"ViTModel\", ViTModel],\n    \"owlvit\" => [\"OwlViTModel\", OwlViTModel],\n    \"donut-swin\" => [\"DonutSwinModel\", DonutSwinModel]\n  }\n\n  MODEL_MAPPING_NAMES_ENCODER_DECODER = {\n    \"bart\" => [\"BartModel\", BartModel]\n  }\n\n  MODEL_MAPPING_NAMES_DECODER_ONLY = {\n  }\n\n  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = {\n    \"whisper\" => [\"WhisperForConditionalGeneration\", 
WhisperForConditionalGeneration]\n  }\n\n  MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = {\n    \"speecht5\" => [\"SpeechT5ForTextToSpeech\", SpeechT5ForTextToSpeech]\n  }\n\n  MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = {\n    \"vits\" => [\"VitsModel\", VitsModel]\n  }\n\n  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {\n    \"bert\" => [\"BertForSequenceClassification\", BertForSequenceClassification],\n    \"modernbert\" => [\"ModernBertForSequenceClassification\", ModernBertForSequenceClassification],\n    \"distilbert\" => [\"DistilBertForSequenceClassification\", DistilBertForSequenceClassification],\n    \"roberta\" => [\"RobertaForSequenceClassification\", RobertaForSequenceClassification],\n    \"xlm-roberta\" => [\"XLMRobertaForSequenceClassification\", XLMRobertaForSequenceClassification],\n    \"bart\" => [\"BartForSequenceClassification\", BartForSequenceClassification]\n  }\n\n  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {\n    \"bert\" => [\"BertForTokenClassification\", BertForTokenClassification],\n    \"modernbert\" => [\"ModernBertForTokenClassification\", ModernBertForTokenClassification],\n    \"roberta\" => [\"RobertaForTokenClassification\", RobertaForTokenClassification]\n  }\n\n  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {\n    \"t5\" => [\"T5ForConditionalGeneration\", T5ForConditionalGeneration],\n    \"bart\" => [\"BartForConditionalGeneration\", BartForConditionalGeneration],\n    \"m2m_100\" => [\"M2M100ForConditionalGeneration\", M2M100ForConditionalGeneration]\n  }\n\n  MODEL_WITH_LM_HEAD_MAPPING_NAMES = {\n    \"gpt2\" => [\"GPT2LMHeadModel\", GPT2LMHeadModel],\n    \"mbart\" => [\"MBartForCausalLM\", MBartForCausalLM]\n  }\n\n  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {\n    \"bert\" => [\"BertForMaskedLM\", BertForMaskedLM],\n    \"modernbert\" => [\"ModernBertForMaskedLM\", ModernBertForMaskedLM],\n    \"roberta\" => [\"RobertaForMaskedLM\", RobertaForMaskedLM]\n  }\n\n  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {\n   
 \"distilbert\" => [\"DistilBertForQuestionAnswering\", DistilBertForQuestionAnswering]\n  }\n\n  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {\n    \"vision-encoder-decoder\" => [\"VisionEncoderDecoderModel\", VisionEncoderDecoderModel]\n  }\n\n  MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {\n    \"vision-encoder-decoder\" => [\"VisionEncoderDecoderModel\", VisionEncoderDecoderModel]\n  }\n\n  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {\n    \"vit\" => [\"ViTForImageClassification\", ViTForImageClassification]\n  }\n\n  MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {\n    \"detr\" => [\"DetrForObjectDetection\", DetrForObjectDetection]\n  }\n\n  MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {\n    \"owlvit\" => [\"OwlViTForObjectDetection\", OwlViTForObjectDetection]\n  }\n\n  MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {\n    \"detr\" => [\"DetrForSegmentation\", DetrForSegmentation]\n  }\n\n  MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_CTC_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = {\n    \"wav2vec2\" => [\"Wav2Vec2ForSequenceClassification\", Wav2Vec2ForSequenceClassification]\n  }\n\n  MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = {\n  }\n\n  MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {\n    \"swin2sr\" => [\"Swin2SRForImageSuperResolution\", Swin2SRForImageSuperResolution]\n  }\n\n  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {\n    \"dpt\" => [\"DPTForDepthEstimation\", DPTForDepthEstimation]\n  }\n\n  MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {\n  }\n\n  MODEL_CLASS_TYPE_MAPPING = [\n    [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],\n    [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES[:DecoderOnly]],\n    
[MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],\n    [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],\n    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],\n    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],\n    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES[:MaskGeneration]],\n    [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],\n    [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],\n    [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]\n  ]\n\n  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|\n    mappings.values.each do |name, model|\n      MODEL_TYPE_MAPPING[name] = type\n      
MODEL_CLASS_TO_NAME_MAPPING[model] = name\n      MODEL_NAME_TO_CLASS_MAPPING[name] = model\n    end\n  end\n\n  class AutoModel < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }\n    BASE_IF_FAIL = true\n  end\n\n  class AutoModelForSequenceClassification < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForTokenClassification < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForSeq2SeqLM < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]\n  end\n\n  class AutoModelForSpeechSeq2Seq < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES]\n  end\n\n  class AutoModelForTextToSpectrogram < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES]\n  end\n\n  class AutoModelForTextToWaveform < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES]\n  end\n\n  class AutoModelForCausalLM < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]\n  end\n\n  class AutoModelForMaskedLM < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]\n  end\n\n  class AutoModelForQuestionAnswering < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]\n  end\n\n  class AutoModelForVision2Seq < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]\n  end\n\n  class AutoModelForImageClassification < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForImageSegmentation < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForSemanticSegmentation < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = 
[MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForObjectDetection < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]\n  end\n\n  class AutoModelForZeroShotObjectDetection < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]\n  end\n\n  class AutoModelForMaskGeneration < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForCTC < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_CTC_MAPPING_NAMES]\n  end\n\n  class AutoModelForAudioClassification < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForXVector < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES]\n  end\n\n  class AutoModelForAudioFrameClassification < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForDocumentQuestionAnswering < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]\n  end\n\n  class AutoModelForImageMatting < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES]\n  end\n\n  class AutoModelForImageToImage < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]\n  end\n\n  class AutoModelForDepthEstimation < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]\n  end\n\n  class AutoModelForImageFeatureExtraction < PretrainedMixin\n    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]\n  end\n\n  class ModelOutput\n    def [](key)\n      instance_variable_get(\"@#{key}\")\n    end\n  end\n\n  class Seq2SeqLMOutput < ModelOutput\n    def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)\n      super()\n      @logits = logits\n  
    @past_key_values = past_key_values\n      @encoder_outputs = encoder_outputs\n      @decoder_attentions = decoder_attentions\n      @cross_attentions = cross_attentions\n    end\n  end\n\n  class SequenceClassifierOutput < ModelOutput\n    attr_reader :logits\n\n    def initialize(logits)\n      super()\n      @logits = logits\n    end\n  end\n\n  class TokenClassifierOutput < ModelOutput\n    attr_reader :logits\n\n    def initialize(logits)\n      super()\n      @logits = logits\n    end\n  end\n\n  class MaskedLMOutput < ModelOutput\n    attr_reader :logits\n\n    def initialize(logits)\n      super()\n      @logits = logits\n    end\n  end\n\n  class QuestionAnsweringModelOutput < ModelOutput\n    attr_reader :start_logits, :end_logits\n\n    def initialize(start_logits, end_logits)\n      super()\n      @start_logits = start_logits\n      @end_logits = end_logits\n    end\n  end\n\n  class DetrObjectDetectionOutput < ModelOutput\n    attr_reader :logits, :pred_boxes\n\n    def initialize(logits, pred_boxes)\n      super()\n      @logits = logits\n      @pred_boxes = pred_boxes\n    end\n  end\n\n  class DetrSegmentationOutput < ModelOutput\n    attr_reader :logits, :pred_boxes, :pred_masks\n\n    def initialize(logits, pred_boxes, pred_masks)\n      super()\n      @logits = logits\n      @pred_boxes = pred_boxes\n      @pred_masks = pred_masks\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/pipelines.rb",
    "content": "module Informers\n  class Pipeline\n    def initialize(task:, model:, tokenizer: nil, processor: nil)\n      super()\n      @task = task\n      @model = model\n      @tokenizer = tokenizer\n      @processor = processor\n    end\n\n    private\n\n    def prepare_images(images)\n      if !images.is_a?(Array)\n        images = [images]\n      end\n\n      # Possibly convert any non-images to images\n      images.map { |x| Utils::RawImage.read(x) }\n    end\n\n    def prepare_audios(audios, sampling_rate)\n      if !audios.is_a?(Array)\n        audios = [audios]\n      end\n\n      audios.map do |x|\n        if x.is_a?(String) || x.is_a?(URI)\n          Utils.read_audio(x, sampling_rate)\n        else\n          x\n        end\n      end\n    end\n\n    def get_bounding_box(box, as_integer)\n      if as_integer\n        box = box.map { |x| x.to_i }\n      end\n      xmin, ymin, xmax, ymax = box\n\n      {xmin:, ymin:, xmax:, ymax:}\n    end\n  end\n\n  class TextClassificationPipeline < Pipeline\n    def call(texts, top_k: 1)\n      # Run tokenization\n      model_inputs = @tokenizer.(texts,\n        padding: true,\n        truncation: true\n      )\n\n      # Run model\n      outputs = @model.(model_inputs)\n\n      function_to_apply =\n        if @model.config[:problem_type] == \"multi_label_classification\"\n          ->(batch) { Utils.sigmoid(batch) }\n        else\n          ->(batch) { Utils.softmax(batch) } # single_label_classification (default)\n        end\n\n      id2label = @model.config[:id2label]\n\n      to_return = []\n      outputs.logits.each do |batch|\n        output = function_to_apply.(batch)\n        scores = Utils.get_top_items(output, top_k)\n\n        vals = scores.map do |x|\n          {\n            label: id2label[x[0].to_s],\n            score: x[1]\n          }\n        end\n        if top_k == 1\n          to_return.concat(vals)\n        else\n          to_return << vals\n        end\n      end\n\n      texts.is_a?(Array) 
? to_return : to_return[0]\n    end\n  end\n\n  class TokenClassificationPipeline < Pipeline\n    def call(\n      texts,\n      ignore_labels: [\"O\"],\n      aggregation_strategy: \"simple\"\n    )\n      is_batched = texts.is_a?(Array)\n\n      # Run tokenization\n      model_inputs = @tokenizer.(is_batched ? texts : [texts],\n        padding: true,\n        truncation: true,\n        return_offsets: true\n      )\n\n      # Run model\n      outputs = @model.(model_inputs)\n\n      logits = outputs.logits\n      id2label = @model.config[:id2label]\n\n      to_return = []\n      logits.length.times do |i|\n        ids = model_inputs[:input_ids][i]\n        batch = logits[i]\n        offsets = model_inputs[:offsets][i]\n\n        # List of tokens that aren't ignored\n        tokens = []\n        batch.length.times do |j|\n          token_data = batch[j]\n          top_score_index = Utils.max(token_data)[1]\n\n          entity = id2label ? id2label[top_score_index.to_s] : \"LABEL_#{top_score_index}\"\n          if ignore_labels.include?(entity)\n            # We predicted a token that should be ignored. So, we skip it.\n            next\n          end\n\n          # TODO add option to keep special tokens?\n          word = @tokenizer.decode([ids[j]], skip_special_tokens: true)\n          if word == \"\"\n            # Was a special token. 
So, we skip it.\n            next\n          end\n\n          scores = Utils.softmax(token_data)\n\n          tokens << {\n            entity: entity,\n            score: scores[top_score_index],\n            index: j,\n            word: word,\n            start: offsets[j][0],\n            end: offsets[j][1]\n          }\n        end\n\n        case aggregation_strategy\n        when \"simple\"\n          tokens = group_entities(tokens)\n        when \"none\"\n          # do nothing\n        else\n          raise ArgumentError, \"Invalid aggregation_strategy\"\n        end\n\n        to_return << tokens\n      end\n      is_batched ? to_return : to_return[0]\n    end\n\n    def group_sub_entities(entities)\n      # Get the first entity in the entity group\n      entity = entities[0][:entity].split(\"-\", 2)[-1]\n      scores = entities.map { |entity| entity[:score] }\n      tokens = entities.map { |entity| entity[:word] }\n\n      entity_group = {\n        entity_group: entity,\n        score: scores.sum / scores.count.to_f,\n        word: @tokenizer.convert_tokens_to_string(tokens),\n        start: entities[0][:start],\n        end: entities[-1][:end]\n      }\n      entity_group\n    end\n\n    def get_tag(entity_name)\n      if entity_name.start_with?(\"B-\")\n        bi = \"B\"\n        tag = entity_name[2..]\n      elsif entity_name.start_with?(\"I-\")\n        bi = \"I\"\n        tag = entity_name[2..]\n      else\n        # It's not in B-, I- format\n        # Default to I- for continuation.\n        bi = \"I\"\n        tag = entity_name\n      end\n      [bi, tag]\n    end\n\n    def group_entities(entities)\n      entity_groups = []\n      entity_group_disagg = []\n\n      entities.each do |entity|\n        if entity_group_disagg.empty?\n          entity_group_disagg << entity\n          next\n        end\n\n        # If the current entity is similar and adjacent to the previous entity,\n        # append it to the disaggregated entity group\n        # The 
split is meant to account for the \"B\" and \"I\" prefixes\n        # Shouldn't merge if both entities are B-type\n        bi, tag = get_tag(entity[:entity])\n        _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])\n\n        if tag == last_tag && bi != \"B\"\n          # Modify subword type to be previous_type\n          entity_group_disagg << entity\n        else\n          # If the current entity is different from the previous entity\n          # aggregate the disaggregated entity group\n          entity_groups << group_sub_entities(entity_group_disagg)\n          entity_group_disagg = [entity]\n        end\n      end\n      if entity_group_disagg.any?\n        # it's the last entity, add it to the entity groups\n        entity_groups << group_sub_entities(entity_group_disagg)\n      end\n\n      entity_groups\n    end\n  end\n\n  class QuestionAnsweringPipeline < Pipeline\n    def call(question, context, top_k: 1)\n      # Run tokenization\n      inputs = @tokenizer.(question,\n        text_pair: context,\n        padding: true,\n        truncation: true,\n        return_offsets: true\n      )\n\n      output = @model.(inputs)\n\n      to_return = []\n      output.start_logits.length.times do |j|\n        ids = inputs[:input_ids][j]\n        sep_index = ids.index(@tokenizer.sep_token_id)\n        offsets = inputs[:offsets][j]\n\n        s1 = Utils.softmax(output.start_logits[j])\n          .map.with_index\n          .select { |x| x[1] > sep_index }\n        e1 = Utils.softmax(output.end_logits[j])\n          .map.with_index\n          .select { |x| x[1] > sep_index }\n\n        options = s1.product(e1)\n          .select { |x| x[0][1] <= x[1][1] }\n          .map { |x| [x[0][1], x[1][1], x[0][0] * x[1][0]] }\n          .sort_by { |v| -v[2] }\n\n        [options.length, top_k].min.times do |k|\n          start, end_, score = options[k]\n\n          answer_tokens = ids.slice(start, end_ + 1)\n\n          answer = @tokenizer.decode(answer_tokens,\n  
          skip_special_tokens: true\n          )\n\n          to_return << {\n            answer:,\n            score:,\n            start: offsets[start][0],\n            end: offsets[end_][1]\n          }\n        end\n      end\n\n      question.is_a?(Array) ? to_return : to_return[0]\n    end\n  end\n\n  class FillMaskPipeline < Pipeline\n    def call(texts, top_k: 5)\n      model_inputs = @tokenizer.(texts, padding: true, truncation: true)\n      outputs = @model.(model_inputs)\n\n      to_return = []\n      model_inputs[:input_ids].each_with_index do |ids, i|\n        mask_token_index = ids.index(@tokenizer.mask_token_id)\n\n        if mask_token_index.nil?\n          raise ArgumentError, \"Mask token (#{@tokenizer.mask_token}) not found in text.\"\n        end\n        logits = outputs.logits[i]\n        item_logits = logits[mask_token_index]\n\n        scores = Utils.get_top_items(Utils.softmax(item_logits), top_k)\n\n        to_return <<\n          scores.map do |x|\n            sequence = ids.dup\n            sequence[mask_token_index] = x[0]\n\n            {\n              score: x[1],\n              token: x[0],\n              token_str: @tokenizer.id_to_token(x[0]),\n              sequence: @tokenizer.decode(sequence, skip_special_tokens: true)\n            }\n          end\n      end\n      texts.is_a?(Array) ? 
to_return : to_return[0]\n    end\n  end\n\n  class Text2TextGenerationPipeline < Pipeline\n    KEY = :generated_text\n\n    def call(texts, **generate_kwargs)\n      if !texts.is_a?(Array)\n        texts = [texts]\n      end\n\n      # Add global prefix, if present\n      if @model.config[:prefix]\n        texts = texts.map { |x| @model.config[:prefix] + x }\n      end\n\n      # Handle task specific params:\n      task_specific_params = @model.config[:task_specific_params]\n      if task_specific_params && task_specific_params[@task]\n        # Add prefixes, if present\n        if task_specific_params[@task][\"prefix\"]\n          texts = texts.map { |x| task_specific_params[@task][\"prefix\"] + x }\n        end\n\n        # TODO update generation config\n      end\n\n      tokenizer = @tokenizer\n      tokenizer_options = {\n        padding: true,\n        truncation: true\n      }\n      if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)\n        input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)[:input_ids]\n      else\n        input_ids = tokenizer.(texts, **tokenizer_options)[:input_ids]\n      end\n\n      output_token_ids = @model.generate(input_ids, generate_kwargs)\n\n      tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)\n        .map { |text| {self.class.const_get(:KEY) => text} }\n    end\n  end\n\n  class SummarizationPipeline < Text2TextGenerationPipeline\n    KEY = :summary_text\n  end\n\n  class TranslationPipeline < Text2TextGenerationPipeline\n    KEY = :translation_text\n  end\n\n  class TextGenerationPipeline < Pipeline\n    def call(texts, **generate_kwargs)\n      is_batched = false\n      is_chat_input = false\n\n      # Normalize inputs\n      if texts.is_a?(String)\n        texts = [texts]\n        inputs = texts\n      else\n        raise Todo\n      end\n\n      # By default, do not add special tokens\n      add_special_tokens = 
generate_kwargs[:add_special_tokens] || false\n\n      # /By default, return full text\n      return_full_text =\n        if is_chat_input\n          false\n        else\n          generate_kwargs[:return_full_text] || true\n        end\n\n      @tokenizer.padding_side = \"left\"\n      input_ids, attention_mask =\n        @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)\n          .values_at(:input_ids, :attention_mask)\n\n      output_token_ids =\n        @model.generate(\n          input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask\n        )\n\n      decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)\n\n      if !return_full_text && Utils.dims(input_ids)[-1] > 0\n        prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }\n      end\n\n      to_return = Array.new(texts.length) { [] }\n      decoded.length.times do |i|\n        text_index = (i / output_token_ids.length.to_i * texts.length).floor\n\n        if prompt_lengths\n          raise Todo\n        end\n        # TODO is_chat_input\n        to_return[text_index] << {\n          generated_text: decoded[i]\n        }\n      end\n      !is_batched && to_return.length == 1 ? to_return[0] : to_return\n    end\n  end\n\n  class ZeroShotClassificationPipeline < Pipeline\n    def initialize(**options)\n      super(**options)\n\n      @label2id = @model.config[:label2id].transform_keys(&:downcase)\n\n      @entailment_id = @label2id[\"entailment\"]\n      if @entailment_id.nil?\n        warn \"Could not find 'entailment' in label2id mapping. Using 2 as entailment_id.\"\n        @entailment_id = 2\n      end\n\n      @contradiction_id = @label2id[\"contradiction\"] || @label2id[\"not_entailment\"]\n      if @contradiction_id.nil?\n        warn \"Could not find 'contradiction' in label2id mapping. 
Using 0 as contradiction_id.\"\n        @contradiction_id = 0\n      end\n    end\n\n    def call(texts, candidate_labels, hypothesis_template: \"This example is {}.\", multi_label: false)\n      is_batched = texts.is_a?(Array)\n      if !is_batched\n        texts = [texts]\n      end\n      if !candidate_labels.is_a?(Array)\n        candidate_labels = [candidate_labels]\n      end\n\n      # Insert labels into hypothesis template\n      hypotheses = candidate_labels.map { |x| hypothesis_template.sub(\"{}\", x) }\n\n      # How to perform the softmax over the logits:\n      #  - true:  softmax over the entailment vs. contradiction dim for each label independently\n      #  - false: softmax the \"entailment\" logits over all candidate labels\n      softmax_each = multi_label || candidate_labels.length == 1\n\n      to_return = []\n      texts.each do |premise|\n        entails_logits = []\n\n        hypotheses.each do |hypothesis|\n          inputs = @tokenizer.(\n            premise,\n            text_pair: hypothesis,\n            padding: true,\n            truncation: true\n          )\n          outputs = @model.(inputs)\n\n          if softmax_each\n            entails_logits << [\n              outputs.logits[0][@contradiction_id],\n              outputs.logits[0][@entailment_id]\n            ]\n          else\n            entails_logits << outputs.logits[0][@entailment_id]\n          end\n        end\n\n        scores =\n          if softmax_each\n            entails_logits.map { |x| Utils.softmax(x)[1] }\n          else\n            Utils.softmax(entails_logits)\n          end\n\n        # Sort by scores (desc) and return scores with indices\n        scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }\n\n        to_return << {\n          sequence: premise,\n          labels: scores_sorted.map { |x| candidate_labels[x[1]] },\n          scores: scores_sorted.map { |x| x[0] }\n        }\n      end\n      is_batched ? 
to_return : to_return[0]\n    end\n  end\n\n  class ImageToTextPipeline < Pipeline\n    def call(images, **generate_kwargs)\n      is_batched = images.is_a?(Array)\n      prepared_images = prepare_images(images)\n\n      pixel_values = @processor.(prepared_images)[:pixel_values]\n\n      to_return = []\n      pixel_values.each do |batch|\n        batch = [batch]\n        output = @model.generate(batch, **generate_kwargs)\n        decoded = @tokenizer\n          .batch_decode(output, skip_special_tokens: true)\n          .map { |x| {generated_text: x.strip} }\n        to_return << decoded\n      end\n\n      is_batched ? to_return : to_return[0]\n    end\n  end\n\n  class ImageClassificationPipeline < Pipeline\n    def call(images, top_k: 1)\n      is_batched = images.is_a?(Array)\n      prepared_images = prepare_images(images)\n\n      pixel_values = @processor.(prepared_images)[:pixel_values]\n      output = @model.({pixel_values: pixel_values})\n\n      id2label = @model.config[:id2label]\n      to_return = []\n      output.logits.each do |batch|\n        scores = Utils.get_top_items(Utils.softmax(batch), top_k)\n\n        vals =\n          scores.map do |x|\n            {\n              label: id2label[x[0].to_s],\n              score: x[1]\n            }\n          end\n        if top_k == 1\n          to_return.push(*vals)\n        else\n          to_return << vals\n        end\n      end\n\n      is_batched || top_k == 1 ? 
to_return : to_return[0]\n    end\n  end\n\n  class ImageSegmentationPipeline < Pipeline\n    def initialize(**options)\n      super(**options)\n\n      @subtasks_mapping = {\n        \"panoptic\" => \"post_process_panoptic_segmentation\",\n        \"instance\" => \"post_process_instance_segmentation\",\n        \"semantic\" => \"post_process_semantic_segmentation\"\n      }\n    end\n\n    def call(\n      images,\n      threshold: 0.5,\n      mask_threshold: 0.5,\n      overlap_mask_area_threshold: 0.8,\n      label_ids_to_fuse: nil,\n      target_sizes: nil,\n      subtask: nil\n    )\n      is_batched = images.is_a?(Array)\n\n      if is_batched && images.length != 1\n        raise Error, \"Image segmentation pipeline currently only supports a batch size of 1.\"\n      end\n\n      prepared_images = prepare_images(images)\n      image_sizes = prepared_images.map { |x| [x.height, x.width] }\n\n      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)\n      output = @model.(model_inputs)\n\n      if !subtask.nil?\n        fn = @subtasks_mapping[subtask]\n      else\n        @subtasks_mapping.each do |task, func|\n          if @processor.feature_extractor.respond_to?(func)\n            fn = @processor.feature_extractor.method(func)\n            subtask = task\n            break\n          end\n        end\n      end\n\n      id2label = @model.config[:id2label]\n\n      annotation = []\n      if subtask == \"panoptic\" || subtask == \"instance\"\n        processed = fn.(\n          output,\n          threshold:,\n          mask_threshold:,\n          overlap_mask_area_threshold:,\n          label_ids_to_fuse:,\n          target_sizes: target_sizes || image_sizes, # TODO FIX?\n        )[0]\n\n        _segmentation = processed[:segmentation]\n\n        processed[:segments_info].each do |segment|\n          annotation << {\n            label: id2label[segment[:label_id].to_s],\n            score: segment[:score]\n            # TODO mask\n    
      }\n        end\n      elsif subtask == \"semantic\"\n        raise Todo\n      else\n        raise Error, \"Subtask #{subtask} not supported.\"\n      end\n\n      annotation\n    end\n  end\n\n  class ZeroShotImageClassificationPipeline < Pipeline\n    def call(images, candidate_labels, hypothesis_template: \"This is a photo of {}\")\n      is_batched = images.is_a?(Array)\n      prepared_images = prepare_images(images)\n\n      # Insert label into hypothesis template\n      texts = candidate_labels.map { |x| hypothesis_template.sub(\"{}\", x) }\n\n      #  Run tokenization\n      text_inputs = @tokenizer.(texts,\n        padding: @model.config[:model_type] == \"siglip\" ? \"max_length\" : true,\n        truncation: true\n      )\n\n      # Run processor\n      pixel_values = @processor.(prepared_images)[:pixel_values]\n\n      # Run model with both text and pixel inputs\n      output = @model.(text_inputs.merge(pixel_values: pixel_values))\n\n      function_to_apply =\n        if @model.config[:model_type] == \"siglip\"\n          ->(batch) { Utils.sigmoid(batch) }\n        else\n          ->(batch) { Utils.softmax(batch) }\n        end\n\n      # Compare each image with each candidate label\n      to_return = []\n      output[0].each do |batch|\n        # Compute softmax per image\n        probs = function_to_apply.(batch)\n\n        result = probs\n          .map.with_index { |x, i| {label: candidate_labels[i], score: x} }\n          .sort_by { |v| -v[:score] }\n\n        to_return << result\n      end\n\n      is_batched ? to_return : to_return[0]\n    end\n  end\n\n  class ObjectDetectionPipeline < Pipeline\n    def call(images, threshold: 0.9, percentage: false)\n      is_batched = images.is_a?(Array)\n\n      if is_batched && images.length != 1\n        raise Error, \"Object detection pipeline currently only supports a batch size of 1.\"\n      end\n      prepared_images = prepare_images(images)\n\n      image_sizes = percentage ? 
nil : prepared_images.map { |x| [x.height, x.width] }\n\n      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)\n      output = @model.(model_inputs)\n\n      processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)\n\n      # Add labels\n      id2label = @model.config[:id2label]\n\n      # Format output\n      result =\n        processed.map do |batch|\n          batch[:boxes].map.with_index do |box, i|\n            {\n              label: id2label[batch[:classes][i].to_s],\n              score: batch[:scores][i],\n              box: get_bounding_box(box, !percentage)\n            }\n          end.sort_by { |v| -v[:score] }\n        end\n\n      is_batched ? result : result[0]\n    end\n  end\n\n  class ZeroShotObjectDetectionPipeline < Pipeline\n    def call(\n      images,\n      candidate_labels,\n      threshold: 0.1,\n      top_k: nil,\n      percentage: false\n    )\n      is_batched = images.is_a?(Array)\n      prepared_images = prepare_images(images)\n\n      # Run tokenization\n      text_inputs = @tokenizer.(candidate_labels,\n        padding: true,\n        truncation: true\n      )\n\n      # Run processor\n      model_inputs = @processor.(prepared_images)\n\n      # Since non-maximum suppression is performed for exporting, we need to\n      # process each image separately. For more information, see:\n      # https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032\n      to_return = []\n      prepared_images.length.times do |i|\n        image = prepared_images[i]\n        image_size = percentage ? 
nil : [[image.height, image.width]]\n        pixel_values = [model_inputs[:pixel_values][i]]\n\n        # Run model with both text and pixel inputs\n        output = @model.(text_inputs.merge(pixel_values: pixel_values))\n        # TODO remove\n        output = @model.instance_variable_get(:@session).outputs.map { |v| v[:name].to_sym }.zip(output).to_h\n\n        processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_size, true)[0]\n        result =\n          processed[:boxes].map.with_index do |box, i|\n            {\n              label: candidate_labels[processed[:classes][i]],\n              score: processed[:scores][i],\n              box: get_bounding_box(box, !percentage)\n            }\n          end\n        result.sort_by! { |v| -v[:score] }\n        if !top_k.nil?\n          result = result[0...top_k]\n        end\n        to_return << result\n      end\n\n      is_batched ? to_return : to_return[0]\n    end\n  end\n\n  class DocumentQuestionAnsweringPipeline < Pipeline\n    def call(image, question, **generate_kwargs)\n      # NOTE: For now, we only support a batch size of 1\n\n      # Preprocess image\n      prepared_image = prepare_images(image)[0]\n      pixel_values = @processor.(prepared_image)[:pixel_values]\n\n      # Run tokenization\n      task_prompt = \"<s_docvqa><s_question>#{question}</s_question><s_answer>\"\n      decoder_input_ids =\n        @tokenizer.(\n          task_prompt,\n          add_special_tokens: false,\n          padding: true,\n          truncation: true\n        )[:input_ids]\n\n      # Run model\n      output =\n        @model.generate(\n          pixel_values,\n          generate_kwargs.merge(\n            decoder_input_ids: decoder_input_ids[0],\n            max_length: @model.config[\"decoder\"][\"max_position_embeddings\"]\n          ).transform_keys(&:to_s)\n        )\n\n      # Decode output\n      decoded = @tokenizer.batch_decode(output, skip_special_tokens: false)[0]\n\n   
   # Parse answer\n      match = decoded.match(/<s_answer>(.*?)<\\/s_answer>/)\n      answer = nil\n      if match && match.length >= 2\n        answer = match[1].strip\n      end\n      [{answer:}]\n    end\n  end\n\n  class TextToAudioPipeline < Pipeline\n    DEFAULT_VOCODER_ID = \"Xenova/speecht5_hifigan\"\n\n    def initialize(**options)\n      super(**options)\n\n      # TODO: Find a better way for `pipeline` to set the default vocoder\n      @vocoder = options[:vocoder]\n    end\n\n    def call(text_inputs, speaker_embeddings: nil)\n      # If this.processor is not set, we are using a `AutoModelForTextToWaveform` model\n      if @processor\n        call_text_to_spectrogram(text_inputs, speaker_embeddings:)\n      else\n        call_text_to_waveform(text_inputs)\n      end\n    end\n  end\n\n  class FeatureExtractionPipeline < Pipeline\n    def call(\n      texts,\n      pooling: \"none\",\n      normalize: false,\n      quantize: false,\n      precision: \"binary\",\n      model_output: nil\n    )\n      # Run tokenization\n      model_inputs = @tokenizer.(texts,\n        padding: true,\n        truncation: true\n      )\n      model_options = {}\n\n      if !model_output.nil?\n        model_options[:output_names] = Array(model_output)\n      elsif @model.instance_variable_get(:@output_names) == [\"token_embeddings\"] && pooling == \"mean\" && normalize\n        # optimization for previous revision of sentence-transformers/all-MiniLM-L6-v2\n        model_options[:output_names] = [\"sentence_embedding\"]\n        pooling = \"none\"\n        normalize = false\n      end\n\n      # Run model\n      outputs = @model.(model_inputs, **model_options)\n\n      # TODO improve\n      result =\n        if outputs.is_a?(Array)\n          # TODO show returned instead of all\n          output_names = @model.instance_variable_get(:@session).outputs.map { |v| v[:name] }\n          raise Error, \"unexpected outputs: #{output_names}\" if outputs.size != 1\n          
outputs[0]\n        else\n          outputs.logits\n        end\n\n      case pooling\n      when \"none\"\n        # Skip pooling\n      when \"mean\"\n        result = Utils.mean_pooling(result, model_inputs[:attention_mask])\n      when \"cls\"\n        result = result.map(&:first)\n      else\n        # TODO raise ArgumentError in 2.0\n        raise Error, \"Pooling method '#{pooling}' not supported.\"\n      end\n\n      if normalize\n        result = Utils.normalize(result)\n      end\n\n      if quantize\n        result = quantize_embeddings(result, precision)\n      end\n\n      texts.is_a?(Array) ? result : result[0]\n    end\n  end\n\n  class ImageFeatureExtractionPipeline < Pipeline\n    def call(images)\n      prepared_images = prepare_images(images)\n      pixel_values = @processor.(prepared_images)[:pixel_values]\n      outputs = @model.({pixel_values: pixel_values})\n\n      result = outputs[0]\n      result\n    end\n  end\n\n  class AudioClassificationPipeline < Pipeline\n    def call(audio, top_k: nil)\n      single = !audio.is_a?(Array)\n\n      sampling_rate = @processor.feature_extractor.config[\"sampling_rate\"]\n      prepared_audios = prepare_audios(audio, sampling_rate)\n\n      id2label = @model.config[:id2label]\n\n      to_return = []\n      prepared_audios.each do |aud|\n        inputs = @processor.(aud)\n        output = @model.(inputs)\n        logits = output.logits[0]\n\n        scores = Utils.get_top_items(Utils.softmax(logits), top_k)\n\n        vals =\n          scores.map do |x|\n            {\n              label: id2label[x[0].to_s],\n              score: x[1]\n            }\n          end\n\n        if top_k == 1\n          to_return.concat(vals)\n        else\n          to_return << vals\n        end\n      end\n      !single || top_k == 1 ? 
to_return : to_return[0]\n    end\n  end\n\n  class ZeroShotAudioClassificationPipeline < Pipeline\n    def call(audio, candidate_labels, hypothesis_template: \"This is a sound of {}.\")\n      single = !audio.is_a?(Array)\n      if single\n        audio = [audio]\n      end\n\n      # Insert label into hypothesis template\n      texts = candidate_labels.map { |x| hypothesis_template.sub(\"{}\", x) }\n\n      # Run tokenization\n      text_inputs =\n        @tokenizer.(\n          texts,\n          padding: true,\n          truncation: true\n        )\n\n      sampling_rate = @processor.feature_extractor.config[\"sampling_rate\"]\n      prepared_audios = prepare_audios(audio, sampling_rate)\n\n      to_return = []\n      prepared_audios.each do |aud|\n        audio_inputs = @processor.(aud)\n\n        # Run model with both text and audio inputs\n        output = @model.(text_inputs.merge(audio_inputs))\n\n        # Compute softmax per audio\n        probs = Utils.softmax(output.logits_per_audio.data)\n\n        to_return <<\n          probs.map.with_index do |x, i|\n            {\n              label: candidate_labels[i],\n              score: x\n            }\n          end\n      end\n      single ? 
to_return[0] : to_return\n    end\n  end\n\n  class AutomaticSpeechRecognitionPipeline < Pipeline\n    def call(audio, **kwargs)\n      case @model.config[\"model_type\"]\n      when \"whisper\"\n        call_whisper(audio, **kwargs)\n      else\n        raise Error, \"AutomaticSpeechRecognitionPipeline does not support model type '#{@model.config[\"model_type\"]}'.\"\n      end\n    end\n\n    private\n\n    def call_whisper(audio, **kwargs)\n      raise Todo\n    end\n  end\n\n  class ImageToImagePipeline < Pipeline\n    def call(images)\n      prepared_images = prepare_images(images)\n      inputs = @processor.(prepared_images)\n      outputs = @model.(inputs)\n\n      to_return = []\n      outputs[0].each do |batch|\n        # TODO flatten first\n        output =\n          batch.map do |v|\n            v.map do |v2|\n              v2.map do |v3|\n                (v3.clamp(0, 1) * 255).round\n              end\n            end\n          end\n        to_return << Utils::RawImage.from_array(output).image\n      end\n\n      to_return.length > 1 ? to_return : to_return[0]\n    end\n  end\n\n  class DepthEstimationPipeline < Pipeline\n    def call(images)\n      prepared_images = prepare_images(images)\n\n      inputs = @processor.(prepared_images)\n      predicted_depth = @model.(inputs)[0]\n\n      to_return = []\n      prepared_images.length.times do |i|\n        prediction = Utils.interpolate(predicted_depth[i], prepared_images[i].size.reverse, \"bilinear\", false)\n        max_prediction = Utils.max(prediction.flatten)[0]\n        formatted =\n          prediction.map do |v|\n            v.map do |v2|\n              v2.map do |v3|\n                (v3 * 255 / max_prediction).round\n              end\n            end\n          end\n        to_return << {\n          predicted_depth: predicted_depth[i],\n          depth: Utils::RawImage.from_array(formatted).image\n        }\n      end\n      to_return.length > 1 ? 
to_return : to_return[0]\n    end\n  end\n\n  class EmbeddingPipeline < FeatureExtractionPipeline\n    def call(\n      texts,\n      pooling: \"mean\",\n      normalize: true,\n      model_output: nil\n    )\n      super(texts, pooling:, normalize:, model_output:)\n    end\n  end\n\n  class RerankingPipeline < Pipeline\n    def call(\n      query,\n      documents,\n      return_documents: false,\n      top_k: nil\n    )\n      model_inputs = @tokenizer.([query] * documents.size,\n        text_pair: documents,\n        padding: true,\n        truncation: true\n      )\n\n      outputs = @model.(model_inputs)\n\n      result =\n        Utils.sigmoid(outputs[0].map(&:first))\n          .map.with_index { |s, i| {doc_id: i, score: s} }\n          .sort_by { |v| -v[:score] }\n\n      if return_documents\n        result.each do |v|\n          v[:text] = documents[v[:doc_id]]\n        end\n      end\n\n      top_k ? result.first(top_k) : result\n    end\n  end\n\n  SUPPORTED_TASKS = {\n    \"text-classification\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: TextClassificationPipeline,\n      model: AutoModelForSequenceClassification,\n      default: {\n        model: \"Xenova/distilbert-base-uncased-finetuned-sst-2-english\"\n      },\n      type: \"text\"\n    },\n    \"token-classification\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: TokenClassificationPipeline,\n      model: AutoModelForTokenClassification,\n      default: {\n        model: \"Xenova/bert-base-multilingual-cased-ner-hrl\"\n      },\n      type: \"text\"\n    },\n    \"question-answering\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: QuestionAnsweringPipeline,\n      model: AutoModelForQuestionAnswering,\n      default: {\n        model: \"Xenova/distilbert-base-cased-distilled-squad\"\n      },\n      type: \"text\"\n    },\n    \"fill-mask\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: FillMaskPipeline,\n      model: AutoModelForMaskedLM,\n      default: {\n   
     model: \"Xenova/bert-base-uncased\"\n      },\n      type: \"text\"\n    },\n    \"summarization\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: SummarizationPipeline,\n      model: AutoModelForSeq2SeqLM,\n      default: {\n        model: \"Xenova/distilbart-cnn-6-6\"\n      },\n      type: \"text\"\n    },\n    \"translation\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: TranslationPipeline,\n      model: AutoModelForSeq2SeqLM,\n      default: {\n        model: \"Xenova/t5-small\"\n      },\n      type: \"text\"\n    },\n    \"text2text-generation\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: Text2TextGenerationPipeline,\n      model: AutoModelForSeq2SeqLM,\n      default: {\n        model: \"Xenova/flan-t5-small\"\n      },\n      type: \"text\"\n    },\n    \"text-generation\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: TextGenerationPipeline,\n      model: AutoModelForCausalLM,\n      default: {\n        model: \"Xenova/gpt2\"\n      },\n      type: \"text\"\n    },\n    \"zero-shot-classification\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: ZeroShotClassificationPipeline,\n      model: AutoModelForSequenceClassification,\n      default: {\n        model: \"Xenova/distilbert-base-uncased-mnli\"\n      },\n      type: \"text\"\n    },\n    \"audio-classification\" => {\n      pipeline: AudioClassificationPipeline,\n      model: AutoModelForAudioClassification,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/wav2vec2-base-superb-ks\"\n      },\n      type: \"audio\"\n    },\n    # TODO\n    # \"zero-shot-audio-classification\" => {\n    #   tokenizer: AutoTokenizer,\n    #   pipeline: ZeroShotAudioClassificationPipeline,\n    #   model: AutoModel,\n    #   processor: AutoProcessor,\n    #   default: {\n    #      model: \"Xenova/clap-htsat-unfused\"\n    #   },\n    #   type: \"multimodal\"\n    # },\n    # TODO\n    # \"automatic-speech-recognition\" => {\n    #   tokenizer: 
AutoTokenizer,\n    #   pipeline: AutomaticSpeechRecognitionPipeline,\n    #   model: [AutoModelForSpeechSeq2Seq, AutoModelForCTC],\n    #   processor: AutoProcessor,\n    #   default: {\n    #     model: \"Xenova/whisper-tiny.en\"\n    #   },\n    #   type: \"multimodal\"\n    # },\n    \"text-to-audio\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: TextToAudioPipeline,\n      model: [AutoModelForTextToWaveform, AutoModelForTextToSpectrogram],\n      processor: [AutoProcessor, nil],\n      default: {\n        model: \"Xenova/speecht5_tts\"\n      },\n      type: \"text\"\n    },\n    \"image-to-text\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: ImageToTextPipeline,\n      model: AutoModelForVision2Seq,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/vit-gpt2-image-captioning\"\n      },\n      type: \"multimodal\"\n    },\n    \"image-classification\" => {\n      pipeline: ImageClassificationPipeline,\n      model: AutoModelForImageClassification,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/vit-base-patch16-224\"\n      },\n      type: \"multimodal\"\n    },\n    \"image-segmentation\" => {\n      pipeline: ImageSegmentationPipeline,\n      model: [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/detr-resnet-50-panoptic\"\n      },\n      type: \"multimodal\"\n    },\n    \"zero-shot-image-classification\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: ZeroShotImageClassificationPipeline,\n      model: AutoModel,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/clip-vit-base-patch32\"\n      },\n      type: \"multimodal\"\n    },\n    \"object-detection\" => {\n      pipeline: ObjectDetectionPipeline,\n      model: AutoModelForObjectDetection,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/detr-resnet-50\"\n      },\n      type: 
\"multimodal\"\n    },\n    \"zero-shot-object-detection\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: ZeroShotObjectDetectionPipeline,\n      model: AutoModelForZeroShotObjectDetection,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/owlvit-base-patch32\"\n      },\n      type: \"multimodal\"\n    },\n    \"document-question-answering\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: DocumentQuestionAnsweringPipeline,\n      model: AutoModelForDocumentQuestionAnswering,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/donut-base-finetuned-docvqa\"\n      },\n      type: \"multimodal\"\n    },\n    \"image-to-image\" => {\n      pipeline: ImageToImagePipeline,\n      model: AutoModelForImageToImage,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/swin2SR-classical-sr-x2-64\"\n      },\n      type: \"image\"\n    },\n    \"depth-estimation\" => {\n      pipeline: DepthEstimationPipeline,\n      model: AutoModelForDepthEstimation,\n      processor: AutoProcessor,\n      default: {\n        model: \"Xenova/dpt-large\"\n      },\n      type: \"image\"\n    },\n    \"feature-extraction\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: FeatureExtractionPipeline,\n      model: AutoModel,\n      default: {\n        model: \"Xenova/all-MiniLM-L6-v2\"\n      },\n      type: \"text\"\n    },\n    \"image-feature-extraction\" => {\n      processor: AutoProcessor,\n      pipeline: ImageFeatureExtractionPipeline,\n      model: [AutoModelForImageFeatureExtraction, AutoModel],\n      default: {\n        model: \"Xenova/vit-base-patch16-224\"\n      },\n      type: \"image\"\n    },\n    \"embedding\" => {\n      tokenizer: AutoTokenizer,\n      pipeline: EmbeddingPipeline,\n      model: AutoModel,\n      default: {\n        model: \"sentence-transformers/all-MiniLM-L6-v2\"\n      },\n      type: \"text\"\n    },\n    \"reranking\" => {\n      tokenizer: AutoTokenizer,\n      
pipeline: RerankingPipeline,\n      model: AutoModel,\n      default: {\n        model: \"mixedbread-ai/mxbai-rerank-base-v1\"\n      },\n      type: \"text\"\n    }\n  }\n\n  TASK_ALIASES = {\n    \"sentiment-analysis\" => \"text-classification\",\n    \"ner\" => \"token-classification\",\n    \"text-to-speech\" => \"text-to-audio\"\n  }\n\n  DEFAULT_PROGRESS_CALLBACK = lambda do |msg|\n    stream = $stderr\n    tty = stream.tty?\n    width = tty ? stream.winsize[1] : 80\n    width = 80 if width == 0\n\n    if msg[:status] == \"progress\" && tty\n      stream.print \"\\r#{Utils::Hub.display_progress(msg[:file], width, msg[:size], msg[:total_size])}\"\n    elsif msg[:status] == \"done\" && !msg[:cache_hit]\n      if tty\n        stream.puts\n      else\n        stream.puts Utils::Hub.display_progress(msg[:file], width, 1, 1)\n      end\n    end\n  end\n\n  NO_DEFAULT = Object.new\n\n  class << self\n    def pipeline(\n      task,\n      model = nil,\n      quantized: NO_DEFAULT,\n      progress_callback: DEFAULT_PROGRESS_CALLBACK,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      device: nil,\n      dtype: nil,\n      model_file_name: nil,\n      session_options: {}\n    )\n      # Apply aliases\n      task = TASK_ALIASES[task] || task\n\n      if quantized == NO_DEFAULT\n        # TODO no quantization by default in 2.0\n        quantized = [\"text-classification\", \"token-classification\", \"question-answering\", \"feature-extraction\"].include?(task)\n      end\n\n      # Get pipeline info\n      pipeline_info = SUPPORTED_TASKS[task.split(\"_\", 1)[0]]\n      if !pipeline_info\n        raise Error, \"Unsupported pipeline: #{task}. Must be one of #{SUPPORTED_TASKS.keys}\"\n      end\n\n      # Use model if specified, otherwise, use default\n      if !model\n        model = pipeline_info[:default][:model]\n        warn \"No model specified. 
Using default model: #{model.inspect}.\"\n      end\n\n      pretrained_options = {\n        quantized:,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:,\n        device:,\n        dtype:,\n        model_file_name:,\n        session_options:\n      }\n\n      classes = {\n        tokenizer: pipeline_info[:tokenizer],\n        model: pipeline_info[:model],\n        processor: pipeline_info[:processor]\n      }\n\n      # Load model, tokenizer, and processor (if they exist)\n      results = load_items(classes, model, pretrained_options)\n      results[:task] = task\n\n      # for previous revision of sentence-transformers/all-MiniLM-L6-v2\n      if model == \"sentence-transformers/all-MiniLM-L6-v2\" && results[:model].instance_variable_get(:@session).outputs.any? { |v| v[:name] == \"token_embeddings\" }\n        results[:model].instance_variable_set(:@output_names, [\"token_embeddings\"])\n      end\n\n      Utils.dispatch_callback(progress_callback, {\n        status: \"ready\",\n        task: task,\n        model: model\n      })\n\n      pipeline_class = pipeline_info.fetch(:pipeline)\n      pipeline_class.new(**results)\n    end\n\n    private\n\n    def load_items(mapping, model, pretrained_options)\n      result = {}\n\n      mapping.each do |name, cls|\n        next if !cls\n\n        if cls.is_a?(Array)\n          e = nil\n          cls.each do |c|\n            begin\n              result[name] = c.from_pretrained(model, **pretrained_options)\n            rescue => err\n              e = err\n            end\n          end\n          raise e unless result[name]\n        else\n          result[name] = cls.from_pretrained(model, **pretrained_options)\n        end\n      end\n\n      result\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/processors.rb",
    "content": "module Informers\n  class FeatureExtractor\n    attr_reader :config\n\n    def initialize(config)\n      super()\n      @config = config\n    end\n  end\n\n  class ImageFeatureExtractor < FeatureExtractor\n    def initialize(config)\n      super(config)\n\n      @image_mean = @config[\"image_mean\"] || @config[\"mean\"]\n      @image_std = @config[\"image_std\"] || @config[\"std\"]\n\n      @resample = @config[\"resample\"] || 2 # 2 => bilinear\n      @do_rescale = @config.fetch(\"do_rescale\", true)\n      @rescale_factor = @config[\"rescale_factor\"] || (1 / 255.0)\n      @do_normalize = @config[\"do_normalize\"]\n\n      @do_resize = @config[\"do_resize\"]\n      @do_thumbnail = @config[\"do_thumbnail\"]\n      @size = @config[\"size\"]\n      @size_divisibility = @config[\"size_divisibility\"] || @config[\"size_divisor\"]\n\n      @do_center_crop = @config[\"do_center_crop\"]\n      @crop_size = @config[\"crop_size\"]\n      @do_convert_rgb = @config.fetch(\"do_convert_rgb\", true)\n      @do_crop_margin = @config[\"do_crop_margin\"]\n\n      @pad_size = @config[\"pad_size\"]\n      @do_pad = @config[\"do_pad\"]\n\n      if @do_pad && !@pad_size && @size && !@size[\"width\"].nil? 
&& !@size[\"height\"].nil?\n        # Should pad, but no pad size specified\n        # We infer the pad size from the resize size\n        @pad_size = @size\n      end\n\n      @do_flip_channel_order = @config[\"do_flip_channel_order\"] || false\n    end\n\n    def thumbnail(image, size, resample = 2)\n      input_height = image.height\n      input_width = image.width\n\n      output_height = size[\"height\"]\n      output_width = size[\"width\"]\n\n      # We always resize to the smallest of either the input or output size.\n      height = [input_height, output_height].min\n      width = [input_width, output_width].min\n\n      if height == input_height && width == input_width\n        return image\n      end\n      if input_height > input_width\n        width = (input_width * height / input_height).floor\n      elsif input_width > input_height\n        height = (input_height * width / input_width).floor\n      end\n      image.resize(width, height, resample:)\n    end\n\n    def pad_image(\n      pixel_data,\n      img_dims,\n      pad_size,\n      mode: \"constant\",\n      center: false,\n      constant_values: 0\n    )\n      image_height, image_width, image_channels = img_dims\n\n      if pad_size.is_a?(Numeric)\n        padded_image_width = pad_size\n        padded_image_height = pad_size\n      else\n        padded_image_width = pad_size[:width] || pad_size[\"width\"]\n        padded_image_height = pad_size[:height] || pad_size[\"height\"]\n      end\n\n      # Only add padding if there is a difference in size\n      if padded_image_width != image_width || padded_image_height != image_height\n        padded_pixel_data = Array.new(padded_image_width * padded_image_height * image_channels)\n        if constant_values.is_a?(Array)\n          # Fill with constant values, cycling through the array\n          padded_pixel_data.length.times do |i|\n            padded_pixel_data[i] = constant_values[i % image_channels]\n          end\n        elsif constant_values 
!= 0\n          padded_pixel_data.fill(constant_values)\n        end\n\n        left, top =\n          if center\n            [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]\n          else\n            [0, 0]\n          end\n\n        # Copy the original image into the padded image\n        image_height.times do |i|\n          a = (i + top) * padded_image_width\n          b = i * image_width\n          image_width.times do |j|\n            c = (a + j + left) * image_channels\n            d = (b + j) * image_channels\n            image_channels.times do |k|\n              padded_pixel_data[c + k] = pixel_data[d + k]\n            end\n          end\n        end\n\n        if mode == \"symmetric\"\n          if center\n            raise Error, \"`center` padding is not supported when `mode` is set to `symmetric`.\"\n          end\n          h1 = image_height - 1\n          w1 = image_width - 1\n          padded_image_height.times do |i|\n            a = i * padded_image_width\n            b = Utils.calculate_reflect_offset(i, h1) * image_width\n\n            padded_image_width.times do |j|\n              next if i < image_height && j < image_width # Do not overwrite original image\n              c = (a + j) * image_channels\n              d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels\n\n              # Copy channel-wise\n              image_channels.times do |k|\n                padded_pixel_data[c + k] = pixel_data[d + k]\n              end\n            end\n          end\n        end\n\n        # Update pixel data and image dimensions\n        pixel_data = padded_pixel_data\n        img_dims = [padded_image_height, padded_image_width, image_channels]\n      end\n      [pixel_data, img_dims]\n    end\n\n    def rescale(pixel_data)\n      pixel_data.length.times do |i|\n        pixel_data[i] *= @rescale_factor\n      end\n    end\n\n    def get_resize_output_image_size(image, size)\n      
src_width, src_height = image.size\n\n      if @do_thumbnail\n        # NOTE: custom logic for `Donut` models\n        height = size[\"height\"]\n        width = size[\"width\"]\n        shortest_edge = [height, width].min\n      elsif size.is_a?(Numeric)\n        shortest_edge = size\n        longest_edge = @config[\"max_size\"] || shortest_edge\n      elsif !size.nil?\n        # Extract known properties from `size`\n        shortest_edge = size[\"shortest_edge\"]\n        longest_edge = size[\"longest_edge\"]\n      end\n\n      if !shortest_edge.nil? || !longest_edge.nil?\n        # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/\n        # Try resize so that shortest edge is `shortest_edge` (target)\n        short_resize_factor =\n          if shortest_edge.nil?\n            1 # If `shortest_edge` is not set, don't upscale\n          else\n            [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max\n          end\n\n        new_width = src_width * short_resize_factor\n        new_height = src_height * short_resize_factor\n\n        # The new width and height might be greater than `longest_edge`, so\n        # we downscale again to ensure the largest dimension is `longest_edge`\n        long_resize_factor =\n          if longest_edge.nil?\n            1 # If `longest_edge` is not set, don't downscale\n          else\n            [longest_edge / new_width.to_f, longest_edge / new_height.to_f].min\n          end\n\n        # To avoid certain floating point precision issues, we round to 2 decimal places\n        final_width = (new_width * long_resize_factor).round(2).floor\n        final_height = (new_height * long_resize_factor).round(2).floor\n\n        if !@size_divisibility.nil?\n          raise Todo\n        end\n        [final_width, final_height]\n      elsif !size.nil? && !size[\"width\"].nil? 
&& !size[\"height\"].nil?\n        new_width = size[\"width\"]\n        new_height = size[\"height\"]\n\n        if @config[\"keep_aspect_ratio\"] && @config[\"ensure_multiple_of\"]\n          raise Todo\n        end\n\n        [new_width, new_height]\n      else\n        raise Todo\n      end\n    end\n\n    def resize(image)\n      new_width, new_height = get_resize_output_image_size(image, @size)\n      image.resize(new_width, new_height, resample: @resample)\n    end\n\n    def preprocess(\n      image,\n      do_normalize: nil,\n      do_pad: nil,\n      do_convert_rgb: nil,\n      do_convert_grayscale: nil,\n      do_flip_channel_order: nil\n    )\n      if @do_crop_margin\n        # NOTE: Specific to nougat processors. This is done before resizing,\n        # and can be interpreted as a pre-preprocessing step.\n        image = crop_margin(image)\n      end\n\n      src_width, src_height = image.size # original image size\n\n      # Convert image to RGB if specified in config.\n      if !do_convert_rgb.nil? ? 
do_convert_rgb : @do_convert_rgb\n        image = image.rgb\n      elsif do_convert_grayscale\n        image = image.grayscale\n      end\n\n      # Resize all images\n      if @do_resize\n        image = resize(image)\n      end\n\n      # Resize the image using thumbnail method.\n      if @do_thumbnail\n        image = thumbnail(image, @size, @resample)\n      end\n\n      if @do_center_crop\n        if @crop_size.is_a?(Integer)\n          crop_width = @crop_size\n          crop_height = @crop_size\n        else\n          crop_width = @crop_size[\"width\"]\n          crop_height = @crop_size[\"height\"]\n        end\n        image = image.center_crop(crop_width, crop_height)\n      end\n\n      reshaped_input_size = [image.height, image.width]\n\n      # NOTE: All pixel-level manipulation (i.e., modifying `pixelData`)\n      # occurs with data in the hwc format (height, width, channels),\n      # to emulate the behavior of the original Python code (w/ numpy).\n      pixel_data = image.data\n      img_dims = [image.height, image.width, image.channels]\n\n      if @do_rescale\n        rescale(pixel_data)\n      end\n\n      if !do_normalize.nil? ? 
do_normalize : @do_normalize\n        image_mean = @image_mean\n        if !@image_mean.is_a?(Array)\n          image_mean = Array.new(image.channels) { image_mean }\n        end\n\n        image_std = @image_std\n        if !@image_std.is_a?(Array)\n          image_std = Array.new(image.channels) { image_std }\n        end\n\n        if image_mean.length != image.channels || image_std.length != image.channels\n          raise Error, \"When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{image.channels}).\"\n        end\n\n        i = 0\n        while i < pixel_data.length\n          image.channels.times do |j|\n            pixel_data[i + j] = (pixel_data[i + j] - image_mean[j]) / image_std[j]\n          end\n          i += image.channels\n        end\n      end\n\n      # do padding after rescaling/normalizing\n      if !do_pad.nil? ? do_pad : @do_pad\n        if @pad_size\n          padded = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)\n          pixel_data, img_dims = padded # Update pixel data and image dimensions\n        elsif @size_divisibility\n          raise Todo\n        end\n      end\n\n      if !do_flip_channel_order.nil? ? 
do_flip_channel_order : @do_flip_channel_order\n        raise Todo\n      end\n\n      # convert to channel dimension format (hwc -> chw)\n      h, w, c = img_dims\n      pixel_values =\n        c.times.map do |ci|\n          h.times.map do |hi|\n            w.times.map do |wi|\n              index = (hi * w * c) + (wi * c) + ci\n              pixel_data[index]\n            end\n          end\n        end\n\n      {\n        original_size: [src_height, src_width],\n        reshaped_input_size: reshaped_input_size,\n        pixel_values: pixel_values\n      }\n    end\n\n    def call(images, *args)\n      if !images.is_a?(Array)\n        images = [images]\n      end\n\n      image_data = images.map { |x| preprocess(x) }\n\n      # Stack pixel values\n      pixel_values = Utils.stack(image_data.map { |x| x[:pixel_values] }, 0)\n\n      {\n        pixel_values: pixel_values,\n\n        # Original sizes of images\n        original_sizes: image_data.map { |x| x[:original_size] },\n\n        # Reshaped sizes of images, before padding or cropping\n        reshaped_input_sizes: image_data.map { |x| x[:reshaped_input_size] }\n      }\n    end\n  end\n\n  class CLIPFeatureExtractor < ImageFeatureExtractor\n  end\n\n  class DPTFeatureExtractor < ImageFeatureExtractor\n  end\n\n  class ViTFeatureExtractor < ImageFeatureExtractor\n  end\n\n  class OwlViTFeatureExtractor < ImageFeatureExtractor\n    def post_process_object_detection(*args)\n      Utils.post_process_object_detection(*args)\n    end\n  end\n\n  class Swin2SRImageProcessor < ImageFeatureExtractor\n    def pad_image(pixel_data, img_dims, pad_size, **options)\n      # NOTE: In this case, `padSize` represents the size of the sliding window for the local attention.\n      # In other words, the image is padded so that its width and height are multiples of `padSize`.\n      image_height, image_width, _image_channels = img_dims\n\n      super(\n        pixel_data,\n        img_dims,\n        {\n          # NOTE: For 
Swin2SR models, the original python implementation adds padding even when the image's width/height is already\n          # a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).\n          # For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`.\n          width: image_width + (pad_size - image_width % pad_size) % pad_size,\n          height: image_height + (pad_size - image_height % pad_size) % pad_size\n        },\n        mode: \"symmetric\",\n        center: false,\n        constant_values: -1,\n        **options\n      )\n    end\n  end\n\n  class DonutFeatureExtractor < ImageFeatureExtractor\n    def pad_image(pixel_data, img_dims, pad_size, **options)\n      _image_height, _image_width, image_channels = img_dims\n\n      image_mean = @image_mean\n      if !image_mean.is_a?(Array)\n        image_mean = Array.new(image_channels, image_mean)\n      end\n\n      image_std = @image_std\n      if !image_std.is_a?(Array)\n        image_std = Array.new(image_channels, image_std)\n      end\n\n      constant_values = image_mean.map.with_index { |x, i| -x / image_std[i] }\n\n      super(\n        pixel_data,\n        img_dims,\n        pad_size,\n        center: true,\n        # Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed.\n        # For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451\n        constant_values: constant_values,\n        **options\n      )\n    end\n  end\n\n  class DetrFeatureExtractor < ImageFeatureExtractor\n    def call(images)\n      result = super(images)\n\n      # TODO support differently-sized images, for now assume all images are the same size.\n      # TODO support different mask sizes (not just 64x64)\n      # Currently, just fill pixel mask with 1s\n      mask_size = 
[result[:pixel_values].size, 64, 64]\n      pixel_mask =\n        mask_size[0].times.map do\n          mask_size[1].times.map do\n            mask_size[2].times.map do\n              1\n            end\n          end\n        end\n\n      result.merge(pixel_mask: pixel_mask)\n    end\n\n    def post_process_object_detection(*args)\n      Utils.post_process_object_detection(*args)\n    end\n\n    def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)\n      mask_probs_item = []\n      pred_scores_item = []\n      pred_labels_item = []\n\n      class_logits.size.times do |j|\n        cls = class_logits[j]\n        mask = mask_logits[j]\n\n        pred_label = Utils.max(cls)[1]\n        if pred_label == num_labels\n          # Is the background, so we ignore it\n          next\n        end\n\n        scores = Utils.softmax(cls)\n        pred_score = scores[pred_label]\n        if pred_score > object_mask_threshold\n          mask_probs_item << mask\n          pred_scores_item << pred_score\n          pred_labels_item << pred_label\n        end\n      end\n\n      [mask_probs_item, pred_scores_item, pred_labels_item]\n    end\n\n    def check_segment_validity(\n      mask_labels,\n      mask_probs,\n      k,\n      mask_threshold = 0.5,\n      overlap_mask_area_threshold = 0.8\n    )\n      # mask_k is a 1D array of indices, indicating where the mask is equal to k\n      mask_k = []\n      mask_k_area = 0\n      original_area = 0\n\n      mask_probs_k_data = mask_probs[k].flatten\n\n      # Compute the area of all the stuff in query k\n      mask_labels.length.times do |i|\n        if mask_labels[i] == k\n          mask_k << i\n          mask_k_area += 1\n        end\n\n        if mask_probs_k_data[i] >= mask_threshold\n          original_area += 1\n        end\n      end\n      mask_exists = mask_k_area > 0 && original_area > 0\n\n      # Eliminate disconnected tiny segments\n      if mask_exists\n        # Perform additional 
check\n        area_ratio = mask_k_area / original_area.to_f\n        mask_exists = area_ratio > overlap_mask_area_threshold\n      end\n\n      [mask_exists, mask_k]\n    end\n\n    def compute_segments(\n      mask_probs,\n      pred_scores,\n      pred_labels,\n      mask_threshold,\n      overlap_mask_area_threshold,\n      label_ids_to_fuse = nil,\n      target_size = nil\n    )\n      height, width = target_size || Utils.dims(mask_probs[0])\n\n      segmentation = Array.new(height * width)\n      segments = []\n\n      # 1. If target_size is not null, we need to resize the masks to the target size\n      if !target_size.nil?\n        # resize the masks to the target size\n        mask_probs.length.times do |i|\n          mask_probs[i] = Utils.interpolate(mask_probs[i], target_size, \"bilinear\", false)\n        end\n      end\n\n      # 2. Weigh each mask by its prediction score\n      # NOTE: `mask_probs` is updated in-place\n      #\n      # Temporary storage for the best label/scores for each pixel ([height, width]):\n      mask_labels = Array.new(mask_probs[0].flatten.length)\n      best_scores = Array.new(mask_probs[0].flatten.length, 0)\n\n      mask_probs.length.times do |i|\n        score = pred_scores[i]\n\n        mask_probs_i_data = mask_probs[i].flatten\n        mask_probs_i_dims = Utils.dims(mask_probs[i])\n\n        mask_probs_i_data.length.times do |j|\n          mask_probs_i_data[j] *= score\n          if mask_probs_i_data[j] > best_scores[j]\n            mask_labels[j] = i\n            best_scores[j] = mask_probs_i_data[j]\n          end\n        end\n\n        mask_probs[i] = Utils.reshape(mask_probs_i_data, mask_probs_i_dims)\n      end\n\n      current_segment_id = 0\n\n      # stuff_memory_list = {}\n      pred_labels.length.times do |k|\n        pred_class = pred_labels[k]\n\n        # TODO add `should_fuse`\n        # should_fuse = label_ids_to_fuse.include?(pred_class)\n\n        # Check if mask exists and large enough to be a segment\n     
   mask_exists, mask_k = check_segment_validity(\n          mask_labels,\n          mask_probs,\n          k,\n          mask_threshold,\n          overlap_mask_area_threshold\n        )\n\n        if !mask_exists\n          # Nothing to see here\n          next\n        end\n\n        current_segment_id += 1\n\n        # Add current object segment to final segmentation map\n        mask_k.each do |index|\n          segmentation[index] = current_segment_id\n        end\n\n        segments << {\n          id: current_segment_id,\n          label_id: pred_class,\n          score: pred_scores[k]\n        }\n      end\n\n      segmentation = Utils.reshape(segmentation, [height, width])\n\n      [segmentation, segments]\n    end\n\n    def post_process_panoptic_segmentation(\n      outputs,\n      threshold: 0.5,\n      mask_threshold: 0.5,\n      overlap_mask_area_threshold: 0.8,\n      label_ids_to_fuse: nil,\n      target_sizes: nil\n    )\n      if label_ids_to_fuse.nil?\n        warn \"`label_ids_to_fuse` unset. No instance will be fused.\"\n        label_ids_to_fuse = Set.new\n      end\n\n      class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]\n      masks_queries_logits = outputs[:pred_masks] # [batch_size, num_queries, height, width]\n\n      mask_probs = Utils.sigmoid(masks_queries_logits) # [batch_size, num_queries, height, width]\n\n      batch_size, _num_queries, num_labels = class_queries_logits.size, class_queries_logits[0].size, class_queries_logits[0][0].size\n      num_labels -= 1 # Remove last class (background)\n\n      if !target_sizes.nil? && target_sizes.length != batch_size\n        raise Error, \"Make sure that you pass in as many target sizes as the batch dimension of the logits\"\n      end\n\n      to_return = []\n      batch_size.times do |i|\n        target_size = !target_sizes.nil? ? 
target_sizes[i] : nil\n\n        class_logits = class_queries_logits[i]\n        mask_logits = mask_probs[i]\n\n        mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels)\n\n        if pred_labels_item.length == 0\n          raise Todo\n        end\n\n        # Get segmentation map and segment information of batch item\n        segmentation, segments = compute_segments(\n          mask_probs_item,\n          pred_scores_item,\n          pred_labels_item,\n          mask_threshold,\n          overlap_mask_area_threshold,\n          label_ids_to_fuse,\n          target_size\n        )\n\n        to_return << {\n          segmentation: segmentation,\n          segments_info: segments\n        }\n      end\n\n      to_return\n    end\n  end\n\n  module Utils\n    def self.center_to_corners_format(v)\n      centerX, centerY, width, height = v\n      [\n        centerX - width / 2.0,\n        centerY - height / 2.0,\n        centerX + width / 2.0,\n        centerY + height / 2.0\n      ]\n    end\n\n    def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)\n      out_logits = outputs[:logits]\n      out_bbox = outputs[:pred_boxes]\n      batch_size, num_boxes, num_classes = out_logits.size, out_logits[0].size, out_logits[0][0].size\n\n      if !target_sizes.nil? && target_sizes.length != batch_size\n        raise Error, \"Make sure that you pass in as many target sizes as the batch dimension of the logits\"\n      end\n      to_return = []\n      batch_size.times do |i|\n        target_size = !target_sizes.nil? ? 
target_sizes[i] : nil\n        info = {\n          boxes: [],\n          classes: [],\n          scores: []\n        }\n        logits = out_logits[i]\n        bbox = out_bbox[i]\n\n        num_boxes.times do |j|\n          logit = logits[j]\n\n          indices = []\n          if is_zero_shot\n            # Get indices of classes with high enough probability\n            probs = Utils.sigmoid(logit)\n            probs.length.times do |k|\n              if probs[k] > threshold\n                indices << k\n              end\n            end\n          else\n            # Get most probable class\n            max_index = Utils.max(logit)[1]\n\n            if max_index == num_classes - 1\n              # This is the background class, skip it\n              next\n            end\n            indices << max_index\n\n            # Compute softmax over classes\n            probs = Utils.softmax(logit)\n          end\n\n          indices.each do |index|\n            box = bbox[j]\n\n            # convert to [x0, y0, x1, y1] format\n            box = center_to_corners_format(box)\n            if !target_size.nil?\n              box = box.map.with_index { |x, i| x * target_size[(i + 1) % 2] }\n            end\n\n            info[:boxes] << box\n            info[:classes] << index\n            info[:scores] << probs[index]\n          end\n        end\n        to_return << info\n      end\n      to_return\n    end\n  end\n\n  class WhisperFeatureExtractor < FeatureExtractor\n    def initialize(config)\n      super(config)\n\n      raise Todo\n    end\n\n    def _extract_fbank_features(waveform)\n      raise Todo\n    end\n\n    def call(audio)\n      raise Todo\n    end\n  end\n\n  class Wav2Vec2FeatureExtractor < FeatureExtractor\n    def _zero_mean_unit_var_norm(input_values)\n      sum = input_values.sum\n      mean = sum / input_values.length.to_f\n      variance = input_values.sum { |b| (b - mean) ** 2 } / input_values.length.to_f\n      input_values.map { |x| (x - mean) 
/ Math.sqrt(variance + 1e-7) }\n    end\n\n    def call(audio)\n      # TODO\n      # validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor')\n\n      input_values = audio\n\n      # zero-mean and unit-variance normalization\n      if @config[\"do_normalize\"]\n        input_values = _zero_mean_unit_var_norm(input_values)\n      end\n\n      # TODO: allow user to pass in attention mask\n      {\n        input_values: [input_values],\n        attention_mask: [Array.new(input_values.length, 1)]\n      }\n    end\n  end\n\n  class ClapFeatureExtractor < FeatureExtractor\n    def initialize(config)\n      super(config)\n\n      # TODO\n    end\n\n    def call(audio, max_length: nil)\n      raise Todo\n    end\n  end\n\n  class Processor\n    attr_reader :feature_extractor\n\n    def initialize(feature_extractor)\n      @feature_extractor = feature_extractor\n    end\n\n    def call(input, *args)\n      @feature_extractor.(input, *args)\n    end\n  end\n\n  class AutoProcessor\n    FEATURE_EXTRACTOR_CLASS_MAPPING = {\n      \"ViTFeatureExtractor\" => ViTFeatureExtractor,\n      \"OwlViTFeatureExtractor\" => OwlViTFeatureExtractor,\n      \"CLIPFeatureExtractor\" => CLIPFeatureExtractor,\n      \"DPTFeatureExtractor\" => DPTFeatureExtractor,\n      \"DetrFeatureExtractor\" => DetrFeatureExtractor,\n      \"Swin2SRImageProcessor\" => Swin2SRImageProcessor,\n      \"DonutFeatureExtractor\" => DonutFeatureExtractor,\n      \"WhisperFeatureExtractor\" => WhisperFeatureExtractor,\n      \"Wav2Vec2FeatureExtractor\" => Wav2Vec2FeatureExtractor,\n      \"ClapFeatureExtractor\" => ClapFeatureExtractor\n    }\n\n    PROCESSOR_CLASS_MAPPING = {}\n\n    def self.from_pretrained(\n      pretrained_model_name_or_path,\n      progress_callback: nil,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      **kwargs\n    )\n      preprocessor_config = config || Utils::Hub.get_model_json(pretrained_model_name_or_path, 
\"preprocessor_config.json\", true,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:\n      )\n\n      # Determine feature extractor class\n      # TODO: Ensure backwards compatibility with old configs\n      key = preprocessor_config[\"feature_extractor_type\"] || preprocessor_config[\"image_processor_type\"]\n      feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING[key]\n\n      if !feature_extractor_class\n        if preprocessor_config[\"size\"]\n          # Assume ImageFeatureExtractor\n          warn \"Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config.\"\n          feature_extractor_class = ImageFeatureExtractor\n        else\n          raise Error, \"Unknown Feature Extractor type: #{key}\"\n        end\n      end\n\n      # If no associated processor class, use default\n      processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config[\"processor_class\"]] || Processor\n\n      # Instantiate processor and feature extractor\n      feature_extractor = feature_extractor_class.new(preprocessor_config)\n      processor_class.new(feature_extractor)\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/tokenizers.rb",
    "content": "module Informers\n  class PreTrainedTokenizer\n    attr_reader :mask_token, :mask_token_id, :sep_token_id\n\n    def initialize(tokenizer_json, tokenizer_config)\n      super()\n\n      @tokenizer_config = tokenizer_config\n\n      @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)\n\n      # Add added_tokens to model\n      @special_tokens = []\n      @all_special_ids = []\n\n      @added_tokens = []\n      @tokenizer.added_tokens_decoder.each do |id, token|\n        @added_tokens << token\n\n        if token.special\n          @special_tokens << token.content\n          @all_special_ids << id\n        end\n      end\n\n      # Update additional_special_tokens\n      @additional_special_tokens = tokenizer_config[\"additional_special_tokens\"] || []\n      @special_tokens.concat(@additional_special_tokens)\n\n      @mask_token = get_token(\"mask_token\")\n      @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token\n\n      @sep_token = get_token(\"sep_token\")\n      @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token\n\n      @model_max_length = tokenizer_config[\"model_max_length\"]\n\n      # for donut-base-finetuned-docvqa\n      if @model_max_length && @model_max_length > (1 << 63)\n        @model_max_length = 1 << 63\n      end\n    end\n\n    def get_token(*keys)\n      keys.each do |key|\n        item = @tokenizer_config[key]\n        if !item\n          next\n        end\n\n        if item.is_a?(Hash)\n          if item[\"__type\"] == \"AddedToken\"\n            return item[\"content\"]\n          else\n            raise Error, \"Unknown token: #{item}\"\n          end\n        else\n          return item\n        end\n      end\n\n      nil\n    end\n\n    def call(\n      text,\n      text_pair: nil,\n      add_special_tokens: true,\n      padding: false,\n      truncation: nil,\n      max_length: nil,\n      return_tensor: true,\n      return_token_type_ids: true, # TODO change default\n      
return_offsets: false\n    )\n      is_batched = text.is_a?(Array)\n\n      if is_batched\n        if text.length == 0\n          raise Error, \"text array must be non-empty\"\n        end\n\n        if !text_pair.nil?\n          if !text_pair.is_a?(Array)\n            raise Error, \"text_pair must also be an array\"\n          elsif text.length != text_pair.length\n            raise Error, \"text and text_pair must have the same length\"\n          end\n        end\n      end\n\n      if padding\n        @tokenizer.enable_padding\n      else\n        @tokenizer.no_padding\n      end\n\n      if truncation\n        @tokenizer.enable_truncation(max_length || @model_max_length)\n      else\n        @tokenizer.no_truncation\n      end\n\n      if is_batched\n        input = text_pair ? text.zip(text_pair) : text\n        encoded = @tokenizer.encode_batch(input, add_special_tokens: add_special_tokens)\n      else\n        encoded = [@tokenizer.encode(text, text_pair, add_special_tokens: add_special_tokens)]\n      end\n\n      result = {input_ids: encoded.map(&:ids), attention_mask: encoded.map(&:attention_mask)}\n      if return_token_type_ids\n        result[:token_type_ids] = encoded.map(&:type_ids)\n      end\n      if return_offsets\n        result[:offsets] = encoded.map(&:offsets)\n      end\n      result\n    end\n\n    def decode(tokens, skip_special_tokens:)\n      @tokenizer.decode(tokens, skip_special_tokens: skip_special_tokens)\n    end\n\n    def convert_tokens_to_string(tokens)\n      @tokenizer.decoder.decode(tokens)\n    end\n\n    def convert_tokens_to_ids(tokens)\n      tokens.map { |t| @tokenizer.token_to_id(t) }\n    end\n\n    def id_to_token(id)\n      @tokenizer.id_to_token(id)\n    end\n\n    def batch_decode(batch, **decode_args)\n      @tokenizer.decode_batch(batch, **decode_args)\n    end\n\n    def padding_side=(side)\n      @tokenizer.enable_padding(direction: side)\n    end\n  end\n\n  class BertTokenizer < PreTrainedTokenizer\n    # 
TODO\n    # self.return_token_type_ids = true\n  end\n\n  class DebertaV2Tokenizer < PreTrainedTokenizer\n    # TODO\n    # self.return_token_type_ids = true\n  end\n\n  class DistilBertTokenizer < PreTrainedTokenizer\n  end\n\n  class T5Tokenizer < PreTrainedTokenizer\n  end\n\n  class GPT2Tokenizer < PreTrainedTokenizer\n    # _default_chat_template = `{% for message in messages %}\" \"{{ message.content }}{{ eos_token }}\" \"{% endfor %}`\n  end\n\n  class BartTokenizer < PreTrainedTokenizer\n  end\n\n  class RobertaTokenizer < PreTrainedTokenizer\n  end\n\n  class XLMRobertaTokenizer < PreTrainedTokenizer\n  end\n\n  class MPNetTokenizer < PreTrainedTokenizer\n  end\n\n  class CLIPTokenizer < PreTrainedTokenizer\n  end\n\n  class NllbTokenizer < PreTrainedTokenizer\n    attr_reader :language_regex, :language_codes, :lang_to_token\n\n    def initialize(tokenizer_json, tokenizer_config)\n      super(tokenizer_json, tokenizer_config)\n\n      @language_regex = /^[a-z]{3}_[A-Z][a-z]{3}$/\n      @language_codes = @special_tokens.filter { |x| @language_regex.match?(x) }\n      @lang_to_token = ->(x) { x } # Identity function\n    end\n\n    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)\n      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)\n    end\n  end\n\n  class M2M100Tokenizer < PreTrainedTokenizer\n    attr_reader :language_regex, :language_codes, :lang_to_token\n\n    def initialize(tokenizer_json, tokenizer_config)\n      super(tokenizer_json, tokenizer_config)\n\n      @language_regex = /^__[a-z]{2,3}__$/\n      @language_codes = @special_tokens\n        .filter { |x| @language_regex.match?(x) }\n        .map { |x| x.slice(2..-3) }\n      @lang_to_token = ->(x) { \"__#{x}__\" }\n    end\n\n    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)\n      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)\n    end\n  end\n\n  module 
Utils\n    def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)\n      if !slf.respond_to?(:language_codes) || !slf.language_codes.is_a?(Array)\n        raise Error, \"Tokenizer must have `language_codes` attribute set and it should be an array of language ids.\"\n      end\n      if !slf.respond_to?(:language_regex) || !slf.language_regex.is_a?(Regexp)\n        raise Error, \"Tokenizer must have `language_regex` attribute set and it should be a regular expression.\"\n      end\n      if !slf.respond_to?(:lang_to_token) || !slf.lang_to_token.respond_to?(:call)\n        raise Error, \"Tokenizer must have `lang_to_token` attribute set and it should be a function.\"\n      end\n      src_lang_token = generate_kwargs[:src_lang]\n      tgt_lang_token = generate_kwargs[:tgt_lang]\n\n      if !slf.language_codes.include?(tgt_lang_token)\n        raise Error, \"Target language code #{tgt_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(\", \")}\"\n      end\n\n      if !src_lang_token.nil?\n        # Check that the source language is valid:\n        if !slf.language_codes.include?(src_lang_token)\n          raise Error, \"Source language code #{src_lang_token.inspect} is not valid. 
Must be one of: #{slf.language_codes.join(\", \")}\"\n        end\n      end\n\n      # Override the `forced_bos_token_id` to force the correct language\n      generate_kwargs[\"forced_bos_token_id\"] = slf.convert_tokens_to_ids([slf.lang_to_token.(tgt_lang_token)])[0]\n\n      slf.(raw_inputs, **tokenizer_options)\n    end\n  end\n\n  class SpeechT5Tokenizer < PreTrainedTokenizer\n  end\n\n  class AutoTokenizer\n    TOKENIZER_CLASS_MAPPING = {\n      \"T5Tokenizer\" => T5Tokenizer,\n      \"BertTokenizer\" => BertTokenizer,\n      \"DebertaV2Tokenizer\" => DebertaV2Tokenizer,\n      \"DistilBertTokenizer\" => DistilBertTokenizer,\n      \"BartTokenizer\" => BartTokenizer,\n      \"RobertaTokenizer\" => RobertaTokenizer,\n      \"XLMRobertaTokenizer\" => XLMRobertaTokenizer,\n      \"MPNetTokenizer\" => MPNetTokenizer,\n      \"CLIPTokenizer\" => CLIPTokenizer,\n      \"GPT2Tokenizer\" => GPT2Tokenizer,\n      \"NllbTokenizer\" => NllbTokenizer,\n      \"M2M100Tokenizer\" => M2M100Tokenizer,\n      \"SpeechT5Tokenizer\" => SpeechT5Tokenizer,\n      \"PreTrainedTokenizer\" => PreTrainedTokenizer\n    }\n\n    def self.from_pretrained(\n      pretrained_model_name_or_path,\n      quantized: true,\n      progress_callback: nil,\n      config: nil,\n      cache_dir: nil,\n      local_files_only: false,\n      revision: \"main\",\n      legacy: nil,\n      **kwargs\n    )\n      tokenizer_json, tokenizer_config = load_tokenizer(\n        pretrained_model_name_or_path,\n        quantized:,\n        progress_callback:,\n        config:,\n        cache_dir:,\n        local_files_only:,\n        revision:,\n        legacy:\n      )\n\n      # Some tokenizers are saved with the \"Fast\" suffix, so we remove that if present.\n      tokenizer_name = tokenizer_config[\"tokenizer_class\"]&.delete_suffix(\"Fast\") || \"PreTrainedTokenizer\"\n\n      cls = TOKENIZER_CLASS_MAPPING[tokenizer_name]\n      if !cls\n        warn \"Unknown tokenizer class #{tokenizer_name.inspect}, 
attempting to construct from base class.\"\n        cls = PreTrainedTokenizer\n      end\n      cls.new(tokenizer_json, tokenizer_config)\n    end\n\n    def self.load_tokenizer(pretrained_model_name_or_path, **options)\n      info = [\n        Utils::Hub.get_model_file(pretrained_model_name_or_path, \"tokenizer.json\", true, **options),\n        Utils::Hub.get_model_json(pretrained_model_name_or_path, \"tokenizer_config.json\", true, **options)\n      ]\n\n      # Override legacy option if `options.legacy` is not null\n      if !options[:legacy].nil?\n        info[1][\"legacy\"] = options[:legacy]\n      end\n      info\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/audio.rb",
    "content": "module Informers\n  module Utils\n    def self.read_audio(input, sampling_rate)\n      data =\n        if input.is_a?(URI)\n          require \"open-uri\"\n\n          input.read\n        elsif input.is_a?(String)\n          File.binread(input)\n        else\n          raise ArgumentError, \"Unsupported input type: #{input.class.name}\"\n        end\n\n      ffmpeg_read(data, sampling_rate)\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/core.rb",
    "content": "module Informers\n  module Utils\n    def self.dispatch_callback(progress_callback, data)\n      progress_callback.(data) if progress_callback\n    end\n\n    def self.calculate_reflect_offset(i, w)\n      ((i + w) % (2 * w) - w).abs\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/dtypes.rb",
    "content": "module Informers\n  module Utils\n    DEFAULT_DTYPE_SUFFIX_MAPPING = {\n      fp32: \"\",\n      fp16: \"_fp16\",\n      int8: \"_int8\",\n      uint8: \"_uint8\",\n      q8: \"_quantized\",\n      q4: \"_q4\",\n      q4f16: \"_q4f16\",\n      bnb4: \"_bnb4\"\n    }\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/ffmpeg.rb",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nmodule Informers\n  module Utils\n    # from the Transformers Python library\n    def self.ffmpeg_read(data, sampling_rate)\n      ar = \"#{sampling_rate}\"\n      ac = \"1\"\n      format_for_conversion = \"f32le\"\n      ffmpeg_command = [\n        \"ffmpeg\",\n        \"-i\",\n        \"pipe:0\",\n        \"-ac\",\n        ac,\n        \"-ar\",\n        ar,\n        \"-f\",\n        format_for_conversion,\n        \"-hide_banner\",\n        \"-loglevel\",\n        \"quiet\",\n        \"pipe:1\"\n      ]\n\n      stdout, status = Open3.capture2(*ffmpeg_command, stdin_data: data)\n      if !status.success?\n        raise Error, \"ffmpeg was not found but is required to load audio files from filename\"\n      end\n      stdout.unpack(\"e*\")\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/generation.rb",
    "content": "module Informers\n  module Utils\n    class GenerationConfig\n      def initialize(kwargs)\n        @config = {}\n\n        # Parameters that control the length of the output\n        @config[\"max_length\"] = kwargs[\"max_length\"] || 20\n        @config[\"max_new_tokens\"] = kwargs[\"max_new_tokens\"]\n        @config[\"min_length\"] = kwargs[\"min_length\"] || 0\n        @config[\"min_new_tokens\"] = kwargs[\"min_new_tokens\"]\n        @config[\"early_stopping\"] = kwargs[\"early_stopping\"] || false\n        @config[\"max_time\"] = kwargs[\"max_time\"]\n\n        # Parameters that control the generation strategy used\n        @config[\"do_sample\"] = kwargs[\"do_sample\"] || false\n        @config[\"num_beams\"] = kwargs[\"num_beams\"] || 1\n        @config[\"num_beam_groups\"] = kwargs[\"num_beam_groups\"] || 1\n        @config[\"penalty_alpha\"] = kwargs[\"penalty_alpha\"]\n        @config[\"use_cache\"] = kwargs.fetch(\"use_cache\", true)\n\n        # Parameters for manipulation of the model output logits\n        @config[\"temperature\"] = kwargs[\"temperature\"] || 1.0\n        @config[\"top_k\"] = kwargs[\"top_k\"] || 50\n        @config[\"top_p\"] = kwargs[\"top_p\"] || 1.0\n        @config[\"typical_p\"] = kwargs[\"typical_p\"] || 1.0\n        @config[\"epsilon_cutoff\"] = kwargs[\"epsilon_cutoff\"] || 0.0\n        @config[\"eta_cutoff\"] = kwargs[\"eta_cutoff\"] || 0.0\n        @config[\"diversity_penalty\"] = kwargs[\"diversity_penalty\"] || 0.0\n        @config[\"repetition_penalty\"] = kwargs[\"repetition_penalty\"] || 1.0\n        @config[\"encoder_repetition_penalty\"] = kwargs[\"encoder_repetition_penalty\"] || 1.0\n        @config[\"length_penalty\"] = kwargs[\"length_penalty\"] || 1.0\n        @config[\"no_repeat_ngram_size\"] = kwargs[\"no_repeat_ngram_size\"] || 0\n        @config[\"bad_words_ids\"] = kwargs[\"bad_words_ids\"]\n        @config[\"force_words_ids\"] = kwargs[\"force_words_ids\"]\n        
@config[\"renormalize_logits\"] = kwargs[\"renormalize_logits\"] || false\n        @config[\"constraints\"] = kwargs[\"constraints\"]\n        @config[\"forced_bos_token_id\"] = kwargs[\"forced_bos_token_id\"]\n        @config[\"forced_eos_token_id\"] = kwargs[\"forced_eos_token_id\"]\n        @config[\"remove_invalid_values\"] = kwargs[\"remove_invalid_values\"] || false\n        @config[\"exponential_decay_length_penalty\"] = kwargs[\"exponential_decay_length_penalty\"]\n        @config[\"suppress_tokens\"] = kwargs[\"suppress_tokens\"]\n        @config[\"begin_suppress_tokens\"] = kwargs[\"begin_suppress_tokens\"]\n        @config[\"forced_decoder_ids\"] = kwargs[\"forced_decoder_ids\"]\n\n        # Parameters that define the output variables of `generate`\n        @config[\"num_return_sequences\"] = kwargs[\"num_return_sequences\"] || 1\n        @config[\"output_attentions\"] = kwargs[\"output_attentions\"] || false\n        @config[\"output_hidden_states\"] = kwargs[\"output_hidden_states\"] || false\n        @config[\"output_scores\"] = kwargs[\"output_scores\"] || false\n        @config[\"return_dict_in_generate\"] = kwargs[\"return_dict_in_generate\"] || false\n\n        # Special tokens that can be used at generation time\n        @config[\"pad_token_id\"] = kwargs[\"pad_token_id\"]\n        @config[\"bos_token_id\"] = kwargs[\"bos_token_id\"]\n        @config[\"eos_token_id\"] = kwargs[\"eos_token_id\"]\n\n        # Generation parameters exclusive to encoder-decoder models\n        @config[\"encoder_no_repeat_ngram_size\"] = kwargs[\"encoder_no_repeat_ngram_size\"] || 0\n        @config[\"decoder_start_token_id\"] = kwargs[\"decoder_start_token_id\"]\n\n        # Wild card\n        @generation_kwargs = kwargs[\"generation_kwargs\"] || {}\n      end\n\n      def [](key)\n        @config[key.to_s]\n      end\n\n      def merge!(config)\n        @config.merge!(config)\n      end\n    end\n\n    class Sampler\n      def initialize(generation_config)\n        
super()\n        @generation_config = generation_config\n      end\n\n      def call(logits, index = -1)\n        # Sample from logits, of dims [batch, sequence_length, vocab_size].\n        # If index is specified, sample from [batch, index, vocab_size].\n        sample(logits, index)\n      end\n\n      def get_logits(logits, index)\n        vocab_size = Utils.dims(logits)[-1]\n\n        logs = logits.flatten\n\n        if index == -1\n          logs = logs.last(vocab_size)\n        else\n          raise Todo\n        end\n\n        # add temperature\n        if @generation_config[\"temperature\"] > 0\n          logs = logs.map { |x| x / @generation_config[\"temperature\"] }\n        end\n        logs\n      end\n\n      def self.get_sampler(generation_config)\n        if generation_config[:do_sample]\n          MultinomialSampler.new(generation_config)\n        elsif generation_config[:num_beams] > 1\n          BeamSearchSampler.new(generation_config)\n        else\n          if generation_config[:num_return_sequences] > 1\n            raise Error, \"num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}.\"\n          end\n          GreedySampler.new(generation_config)\n        end\n      end\n    end\n\n    class GreedySampler < Sampler\n      def sample(logits, index = -1)\n        # NOTE: no need to do log_softmax here since we only take the maximum\n        logs = get_logits(logits, index)\n        argmax = Utils.max(logs)[1]\n\n        # Note: score is meaningless in this context, since we are performing\n        # greedy search (p = 1 => log(p) = 0)\n        [\n          [argmax, 0]\n        ]\n      end\n    end\n\n    class BeamSearchSampler < Sampler\n      def sample(logits, index = -1)\n        k = Utils.dims(logits)[-1] # defaults to vocab size\n        if @generation_config[\"top_k\"] > 0\n          k = [@generation_config[\"top_k\"], k].min\n        end\n\n        # Get logits of nth token\n    
    logs = get_logits(logits, index)\n\n        # Get top k tokens\n        top_logits = Utils.get_top_items(logs, k)\n\n        # Compute softmax over logits\n        probabilities = Utils.softmax(top_logits.map { |x| x[1] })\n\n        Array.new(@generation_config[\"num_beams\"]) do |i|\n          [\n            top_logits[i][0],\n            Math.log(probabilities[i])\n          ]\n        end\n      end\n    end\n\n    class LogitsProcessorList\n      def initialize\n        super\n        @processors = []\n      end\n\n      def push(item)\n        @processors << item\n      end\n\n      def concat(items)\n        @processors.concat(items)\n      end\n\n      def call(input_ids, batched_logits)\n        # NOTE: This is different from the Python code, since vanilla Ruby does not support vectorized operations.\n        # As a result, we apply each processor to each item in the batch.\n        batched_logits.each do |logits|\n          # Modifies logits inplace\n          @processors.each do |func|\n            func.(input_ids, logits)\n          end\n        end\n      end\n\n      def to_ary\n        @processors\n      end\n    end\n\n    class LogitsProcessor\n    end\n\n    class NoRepeatNGramLogitsProcessor < LogitsProcessor\n      def initialize(no_repeat_ngram_size)\n        super()\n        @no_repeat_ngram_size = no_repeat_ngram_size\n      end\n\n      def get_ngrams(prev_input_ids)\n        cur_len = prev_input_ids.length\n\n        ngrams = []\n        j = 0\n        while j < cur_len + 1 - @no_repeat_ngram_size\n          ngram = []\n          @no_repeat_ngram_size.times do |k|\n            ngram << prev_input_ids[j + k]\n          end\n          ngrams << ngram\n          j += 1\n        end\n\n        generated_ngram = {}\n        ngrams.each do |ngram|\n          prev_ngram = ngram.slice(0, ngram.length - 1)\n          prev_ngram_key = JSON.generate(prev_ngram)\n          prev_ngram_value = generated_ngram[prev_ngram_key] || []\n          
prev_ngram_value << ngram[ngram.length - 1]\n          generated_ngram[prev_ngram_key] = prev_ngram_value\n        end\n        generated_ngram\n      end\n\n      def get_generated_ngrams(banned_ngrams, prev_input_ids)\n        ngram_idx = prev_input_ids.slice(prev_input_ids.length + 1 - @no_repeat_ngram_size, prev_input_ids.length)\n        banned = banned_ngrams[JSON.generate(ngram_idx)] || []\n        banned\n      end\n\n      def calc_banned_ngram_tokens(prev_input_ids)\n        banned_tokens = []\n        if prev_input_ids.length + 1 < @no_repeat_ngram_size\n          # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet\n          banned_tokens\n        else\n          generated_ngrams = get_ngrams(prev_input_ids)\n          banned_tokens = get_generated_ngrams(generated_ngrams, prev_input_ids)\n          banned_tokens\n        end\n      end\n\n      def call(input_ids, logits)\n        banned_tokens = calc_banned_ngram_tokens(input_ids)\n\n        banned_tokens.each do |token|\n          logits[token] = -Float::INFINITY\n        end\n        logits\n      end\n    end\n\n    class MinLengthLogitsProcessor < LogitsProcessor\n      def initialize(min_length, eos_token_id)\n        super()\n        @min_length = min_length\n        @eos_token_id = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]\n      end\n\n      def call(input_ids, logits)\n        if input_ids.length < @min_length\n          @eos_token_id.each do |eos_token|\n            logits[eos_token] = -Float::INFINITY\n          end\n        end\n\n        logits\n      end\n    end\n\n    class ForcedBOSTokenLogitsProcessor < LogitsProcessor\n      def initialize(bos_token_id)\n        super()\n        @bos_token_id = bos_token_id\n      end\n\n      def call(input_ids, logits)\n        if input_ids.length == 1\n          logits.map! 
{ -Float::INFINITY }\n          logits[@bos_token_id] = 0\n        end\n        logits\n      end\n    end\n\n    class ForcedEOSTokenLogitsProcessor < LogitsProcessor\n      def initialize(max_length, forced_eos_token_id)\n        super()\n        @max_length = max_length\n        @forced_eos_token_id = forced_eos_token_id\n      end\n\n      def call(input_ids, logits)\n      end\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/hub.rb",
    "content": "module Informers\n  module Utils\n    module Hub\n      class FileResponse\n        attr_reader :exists, :status\n\n        def initialize(file_path)\n          @file_path = file_path\n\n          @exists = File.exist?(file_path)\n          if @exists\n            @status = [\"200\", \"OK\"]\n          else\n            @status = [\"404\", \"Not Found\"]\n          end\n        end\n\n        def read\n          File.binread(@file_path)\n        end\n      end\n\n      def self.is_valid_url(string, protocols = nil, valid_hosts = nil)\n        begin\n          url = URI.parse(string)\n        rescue\n          return false\n        end\n        if protocols && !protocols.include?(url.scheme)\n          return false\n        end\n        if valid_hosts && !valid_hosts.include?(url.host)\n          return false\n        end\n        true\n      end\n\n      def self.get_file(url_or_path, progress_callback = nil, progress_info = {})\n        if !is_valid_url(url_or_path, [\"http\", \"https\"])\n          raise Error, \"Invalid url\"\n        else\n          headers = {}\n          headers[\"User-Agent\"] = \"informers/#{VERSION};\"\n\n          # Check whether we are making a request to the Hugging Face Hub.\n          is_hfurl = is_valid_url(url_or_path, [\"http\", \"https\"], [\"huggingface.co\", \"hf.co\"])\n          if is_hfurl\n            # If an access token is present in the environment variables,\n            # we add it to the request headers.\n            token = ENV[\"HF_TOKEN\"]\n            if token\n              headers[\"Authorization\"] = \"Bearer #{token}\"\n            end\n          end\n          options = {}\n          if progress_callback\n            total_size = nil\n            options[:content_length_proc] = lambda do |size|\n              total_size = size\n              Utils.dispatch_callback(progress_callback, {status: \"download\"}.merge(progress_info).merge(total_size: size))\n            end\n            
options[:progress_proc] = lambda do |size|\n              Utils.dispatch_callback(progress_callback, {status: \"progress\"}.merge(progress_info).merge(size: size, total_size: total_size))\n            end\n          end\n          URI.parse(url_or_path).open(**headers, **options)\n        end\n      end\n\n      class FileCache\n        attr_reader :path\n\n        def initialize(path)\n          @path = path\n        end\n\n        def match(request)\n          file_path = resolve_path(request)\n          file = FileResponse.new(file_path)\n\n          file if file.exists\n        end\n\n        def put(request, response)\n          output_path = resolve_path(request)\n\n          begin\n            tmp_path = \"#{output_path}.incomplete\"\n            FileUtils.mkdir_p(File.dirname(output_path))\n            File.open(tmp_path, \"wb\") do |f|\n              while !response.eof?\n                f.write(response.read(1024 * 1024))\n              end\n            end\n            FileUtils.move(tmp_path, output_path)\n          rescue => e\n            warn \"An error occurred while writing the file to cache: #{e}\"\n          end\n        end\n\n        def resolve_path(request)\n          File.join(@path, request)\n        end\n      end\n\n      def self.try_cache(cache, *names)\n        names.each do |name|\n          begin\n            result = cache.match(name)\n            return result if result\n          rescue\n            next\n          end\n        end\n        nil\n      end\n\n      def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)\n        # Initiate file retrieval\n        Utils.dispatch_callback(options[:progress_callback], {\n          status: \"initiate\",\n          name: path_or_repo_id,\n          file: filename\n        })\n\n        # If `cache_dir` is not specified, use the default cache directory\n        cache = FileCache.new(options[:cache_dir] || Informers.cache_dir)\n\n        revision = options[:revision] 
|| \"main\"\n\n        request_url = path_join(path_or_repo_id, filename)\n\n        remote_url = path_join(\n          Informers.remote_host,\n          Informers.remote_path_template\n            .gsub(\"{model}\", path_or_repo_id)\n            .gsub(\"{revision}\", URI.encode_www_form_component(revision)),\n          filename\n        )\n\n        # Choose cache key for filesystem cache\n        # When using the main revision (default), we use the request URL as the cache key.\n        # If a specific revision is requested, we account for this in the cache key.\n        fs_cache_key = revision == \"main\" ? request_url : path_join(path_or_repo_id, revision, filename)\n\n        proposed_cache_key = fs_cache_key\n\n        resolved_path = cache.resolve_path(proposed_cache_key)\n\n        # Whether to cache the final response in the end.\n        to_cache_response = false\n\n        # A caching system is available, so we try to get the file from it.\n        response = try_cache(cache, proposed_cache_key)\n\n        cache_hit = !response.nil?\n\n        if response.nil?\n          # File is not cached, so we perform the request\n\n          if response.nil? || response.status[0] == \"404\"\n            # File not found locally. 
This means either:\n            # - The user has disabled local file access (`Informers.allow_local_models = false`)\n            # - the path is a valid HTTP url (`response.nil?`)\n            # - the path is not a valid HTTP url and the file is not present on the file system or local server (`response.status[0] == \"404\"`)\n\n            if options[:local_files_only] || !Informers.allow_remote_models\n              # User requested local files only, but the file is not found locally.\n              if fatal\n                raise Error, \"`local_files_only: true` or `Informers.allow_remote_models = false` and file was not found locally at #{resolved_path.inspect}.\"\n              else\n                # File not found, but this file is optional.\n                # TODO in future, cache the response?\n                return nil\n              end\n            end\n\n            progress_info = {\n              name: path_or_repo_id,\n              file: filename\n            }\n\n            # File not found locally, so we try to download it from the remote server\n            response = get_file(remote_url, options[:progress_callback], progress_info)\n\n            if response.status[0] != \"200\"\n              # should not happen\n              raise Todo\n            end\n\n            # Success! 
We use the proposed cache key from earlier\n            cache_key = proposed_cache_key\n          end\n\n          to_cache_response = cache && !response.is_a?(FileResponse) && response.status[0] == \"200\"\n        end\n\n        if to_cache_response && cache_key && cache.match(cache_key).nil?\n          cache.put(cache_key, response)\n        end\n\n        Utils.dispatch_callback(options[:progress_callback], {\n          status: \"done\",\n          name: path_or_repo_id,\n          file: filename,\n          cache_hit: cache_hit\n        })\n\n        resolved_path\n      end\n\n      def self.get_model_json(model_path, file_name, fatal = true, **options)\n        buffer = get_model_file(model_path, file_name, fatal, **options)\n        if buffer.nil?\n          # Return empty object\n          return {}\n        end\n\n        JSON.load_file(buffer)\n      end\n\n      def self.path_join(*parts)\n        parts = parts.map.with_index do |part, index|\n          if index != 0\n            part = part.delete_prefix(\"/\")\n          end\n          if index != parts.length - 1\n            part = part.delete_suffix(\"/\")\n          end\n          part\n        end\n        parts.join(\"/\")\n      end\n\n      def self.display_progress(filename, width, size, expected_size)\n        bar_width = [width - (filename.length + 3), 1].max\n        progress = expected_size && expected_size > 0 ? size / expected_size.to_f : 0\n        done = (progress * bar_width).round\n        not_done = bar_width - done\n        \"#{filename} |#{\"█\" * done}#{\" \" * not_done}|\"\n      end\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/image.rb",
    "content": "module Informers\n  module Utils\n    class RawImage\n      RESAMPLING_MAPPING = {\n        0 => \"nearest\",\n        1 => \"lanczos\",\n        2 => \"bilinear\",\n        3 => \"bicubic\",\n        4 => \"box\",\n        5 => \"hamming\"\n      }\n\n      attr_reader :image, :width, :height, :channels\n\n      def initialize(image)\n        @image = image\n        @width = image.width\n        @height = image.height\n        @channels = image.bands\n      end\n\n      def data\n        @image.write_to_memory.unpack(\"C*\")\n      end\n\n      def size\n        [@width, @height]\n      end\n\n      def resize(width, height, resample: 2)\n        resample_method = RESAMPLING_MAPPING[resample] || resample\n\n        case resample_method\n        when \"bilinear\", \"bicubic\"\n          img =\n            @image.affine(\n              [width / @width.to_f, 0, 0, height / @height.to_f],\n              interpolate: Vips::Interpolate.new(resample_method.to_sym)\n            )\n        else\n          raise Todo\n        end\n\n        RawImage.new(img)\n      end\n\n      def center_crop(crop_width, crop_height)\n        # If the image is already the desired size, return it\n        if @width == crop_width && @height == crop_height\n          return self\n        end\n\n        # Determine bounds of the image in the new canvas\n        width_offset = (@width - crop_width) / 2.0\n        height_offset = (@height - crop_height) / 2.0\n\n        if width_offset >= 0 && height_offset >= 0\n          # Cropped image lies entirely within the original image\n          img = @image.crop(\n            width_offset.floor,\n            height_offset.floor,\n            crop_width,\n            crop_height\n          )\n        elsif width_offset <= 0 && height_offset <= 0\n          raise Todo\n        else\n          raise Todo\n        end\n\n        RawImage.new(img)\n      end\n\n      def rgb\n        if @channels == 3\n          return self\n        end\n\n 
       raise Todo\n      end\n\n      def save(path)\n        @image.write_to_file(path)\n      end\n\n      def self.read(input)\n        if input.is_a?(RawImage)\n          input\n        elsif input.is_a?(URI)\n          require \"open-uri\"\n\n          RawImage.new(Vips::Image.new_from_buffer(input.read, \"\"))\n        elsif input.is_a?(String)\n          RawImage.new(Vips::Image.new_from_file(input))\n        else\n          raise ArgumentError, \"Unsupported input type: #{input.class.name}\"\n        end\n      end\n\n      def self.from_array(input)\n        c, h, w = Utils.dims(input)\n        pixel_data = Array.new(w * h * c)\n\n        input.each_with_index do |cv, ci|\n          cv.each_with_index do |hv, hi|\n            hv.each_with_index do |v, wi|\n              pixel_data[(hi * w * c) + (wi * c) + ci] = v\n            end\n          end\n        end\n\n        RawImage.new(Vips::Image.new_from_memory_copy(pixel_data.pack(\"C*\"), w, h, c, :uchar))\n      end\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/math.rb",
    "content": "module Informers\n  module Utils\n    def self.interpolate_data(input, in_shape, out_shape, mode = \"bilinear\", align_corners = false)\n      in_channels, in_height, in_width = in_shape\n      out_height, out_width = out_shape\n\n      # TODO use mode and align_corners\n\n      # Output image dimensions\n      x_scale = out_width / in_width.to_f\n      y_scale = out_height / in_height.to_f\n\n      # Output image\n      out_img = Array.new(out_height * out_width * in_channels)\n\n      # Pre-calculate strides\n      in_stride = in_height * in_width\n      out_stride = out_height * out_width\n\n      out_height.times do |i|\n        out_width.times do |j|\n          # Calculate output offset\n          out_offset = i * out_width + j\n\n          # Calculate input pixel coordinates\n          x = (j + 0.5) / x_scale - 0.5\n          y = (i + 0.5) / y_scale - 0.5\n\n          # Calculate the four nearest input pixels\n          # We also check if the input pixel coordinates are within the image bounds\n          x1 = x.floor\n          y1 = y.floor\n          x2 = [x1 + 1, in_width - 1].min\n          y2 = [y1 + 1, in_height - 1].min\n\n          x1 = [x1, 0].max\n          y1 = [y1, 0].max\n\n          # Calculate the fractional distances between the input pixel and the four nearest pixels\n          s = x - x1\n          t = y - y1\n\n          # Perform bilinear interpolation\n          w1 = (1 - s) * (1 - t)\n          w2 = s * (1 - t)\n          w3 = (1 - s) * t\n          w4 = s * t\n\n          # Calculate the four nearest input pixel indices\n          y_stride = y1 * in_width\n          x_stride = y2 * in_width\n          idx1 = y_stride + x1\n          idx2 = y_stride + x2\n          idx3 = x_stride + x1\n          idx4 = x_stride + x2\n\n          in_channels.times do |k|\n            # Calculate channel offset\n            c_offset = k * in_stride\n\n            out_img[k * out_stride + out_offset] =\n              w1 * input[c_offset + 
idx1] +\n              w2 * input[c_offset + idx2] +\n              w3 * input[c_offset + idx3] +\n              w4 * input[c_offset + idx4]\n          end\n        end\n      end\n\n      out_img\n    end\n\n    def self.softmax(arr)\n      # Compute the maximum value in the array\n      max_val = arr.max\n\n      #  Compute the exponentials of the array values\n      exps = arr.map { |x| Math.exp(x - max_val) }\n\n      # Compute the sum of the exponentials\n      sum_exps = exps.sum\n\n      # Compute the softmax values\n      softmax_arr = exps.map { |x| x / sum_exps }\n\n      softmax_arr\n    end\n\n    def self.sigmoid(arr)\n      if arr[0].is_a?(Array)\n        return arr.map { |a| sigmoid(a) }\n      end\n      arr.map { |v| 1 / (1 + Math.exp(-v)) }\n    end\n\n    def self.get_top_items(items, top_k = 0)\n      # if top == 0, return all\n\n      items = items\n        .map.with_index { |x, i| [i, x] } # Get indices ([index, score])\n        .sort_by { |v| -v[1] }            # Sort by log probabilities\n\n      if !top_k.nil? && top_k > 0\n        items = items.slice(0, top_k)     # Get top k items\n      end\n\n      items\n    end\n\n    def self.max(arr)\n      if arr.length == 0\n        raise Error, \"Array must not be empty\"\n      end\n      arr.map.with_index.max_by { |v, _| v }\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/utils/tensor.rb",
    "content": "module Informers\n  module Utils\n    def self.mean_pooling(last_hidden_state, attention_mask)\n      last_hidden_state.zip(attention_mask).map do |state, mask|\n        state[0].size.times.map do |k|\n          sum = 0.0\n          count = 0\n\n          state.zip(mask) do |s, m|\n            count += m\n            sum += s[k] * m\n          end\n\n          sum / count\n        end\n      end\n    end\n\n    def self.normalize(result)\n      result.map do |row|\n        norm = Math.sqrt(row.sum { |v| v * v })\n        row.map { |v| v / norm }\n      end\n    end\n\n    def self.stack(tensors, dim = 0)\n      tensors\n    end\n\n    def self.ones_like(tensor)\n      if tensor[0].is_a?(Array)\n        return tensor.map { |v| ones_like(v) }\n      end\n      tensor.map { |_| 1 }\n    end\n\n    def self.dims(tensor)\n      dims = []\n      while tensor.is_a?(Array)\n        dims << tensor.size\n        tensor = tensor[0]\n      end\n      dims\n    end\n\n    def self.interpolate(input, shape, mode = \"bilinear\", align_corners = false)\n      out_height, out_width = shape\n\n      # Input image dimensions\n      in_channels = dims(input)[-3] || 1\n      in_height = dims(input)[-2]\n      in_width = dims(input)[-1]\n\n      output = interpolate_data(\n        input.flatten,\n        [in_channels, in_height, in_width],\n        [out_height, out_width],\n        mode,\n        align_corners\n      )\n      reshape(output, [in_channels, out_height, out_width])\n    end\n\n    def self.reshape(arr, dims)\n      arr = arr.flatten\n      dims[1..-1].reverse_each do |dim|\n        arr = arr.each_slice(dim)\n      end\n      arr.to_a\n    end\n  end\nend\n"
  },
  {
    "path": "lib/informers/version.rb",
    "content": "module Informers\n  VERSION = \"1.2.1\"\nend\n"
  },
  {
    "path": "lib/informers.rb",
    "content": "# dependencies\nrequire \"onnxruntime\"\nrequire \"tokenizers\"\n\n# stdlib\nrequire \"io/console\"\nrequire \"json\"\nrequire \"open-uri\"\nrequire \"open3\"\nrequire \"stringio\"\nrequire \"uri\"\n\n# modules\nrequire_relative \"informers/backends/onnx\"\nrequire_relative \"informers/utils/audio\"\nrequire_relative \"informers/utils/core\"\nrequire_relative \"informers/utils/dtypes\"\nrequire_relative \"informers/utils/generation\"\nrequire_relative \"informers/utils/ffmpeg\"\nrequire_relative \"informers/utils/hub\"\nrequire_relative \"informers/utils/image\"\nrequire_relative \"informers/utils/math\"\nrequire_relative \"informers/utils/tensor\"\nrequire_relative \"informers/configs\"\nrequire_relative \"informers/env\"\nrequire_relative \"informers/model\"\nrequire_relative \"informers/models\"\nrequire_relative \"informers/processors\"\nrequire_relative \"informers/tokenizers\"\nrequire_relative \"informers/version\"\nrequire_relative \"informers/pipelines\"\n\nmodule Informers\n  class Error < StandardError; end\n\n  class Todo < Error\n    def message\n      \"not implemented yet\"\n    end\n  end\nend\n"
  },
  {
    "path": "test/model_test.rb",
    "content": "require_relative \"test_helper\"\n\nclass ModelTest < Minitest::Test\n  # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2\n  def test_all_minilm\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\n    model = Informers.pipeline(\"embedding\", \"sentence-transformers/all-MiniLM-L6-v2\")\n    embeddings = model.(sentences)\n\n    assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]\n    assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]\n  end\n\n  # https://huggingface.co/Xenova/all-MiniLM-L6-v2\n  def test_all_minilm_xenova\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\n    model = Informers.pipeline(\"embedding\", \"Xenova/all-MiniLM-L6-v2\", dtype: \"q8\")\n    embeddings = model.(sentences)\n\n    assert_elements_in_delta [0.045927, 0.07328, 0.05401], embeddings[0][..2]\n    assert_elements_in_delta [0.081881, 0.1076, -0.01324], embeddings[1][..2]\n  end\n\n  # https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1\n  def test_multi_qa_minilm\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\n    model = Informers.pipeline(\"embedding\", \"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\")\n    query_embedding = model.(query)\n    doc_embeddings = model.(docs)\n    scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }\n    doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }\n\n    assert_equal \"Around 9 Million people live in London\", doc_score_pairs[0][0]\n    assert_in_delta 0.9156, doc_score_pairs[0][1]\n    assert_equal \"London is known for its financial district\", doc_score_pairs[1][0]\n    assert_in_delta 0.4948, doc_score_pairs[1][1]\n  end\n\n  # https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2\n  def test_paraphrase_minilm\n    
sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\n    model = Informers.pipeline(\"embedding\", \"sentence-transformers/paraphrase-MiniLM-L6-v2\")\n    embeddings = model.(sentences, normalize: false)\n\n    assert_elements_in_delta [0.067359, 0.783935, 0.270018], embeddings[0][..2]\n    assert_elements_in_delta [0.122117, 0.670228, 0.317166], embeddings[1][..2]\n  end\n\n  # https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1\n  def test_mxbai_embed\n    query_prefix = \"Represent this sentence for searching relevant passages: \"\n\n    input = [\n      \"The dog is barking\",\n      \"The cat is purring\",\n      query_prefix + \"puppy\"\n    ]\n\n    model = Informers.pipeline(\"embedding\", \"mixedbread-ai/mxbai-embed-large-v1\")\n    embeddings = model.(input, pooling: \"cls\", normalize: false)\n\n    assert_elements_in_delta [-0.61227727, 1.4060247, -0.04079155], embeddings[1][..2]\n    assert_elements_in_delta [-0.00624076, 0.12864432, 0.5248165], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/Supabase/gte-small\n  def test_gte_small\n    sentences = [\"That is a happy person\", \"That is a very happy person\"]\n\n    model = Informers.pipeline(\"embedding\", \"Supabase/gte-small\")\n    embeddings = model.(sentences)\n\n    assert_elements_in_delta [-0.05316979, 0.01044252, 0.06194701], embeddings[0][..2]\n    assert_elements_in_delta [-0.05246907, 0.03752426, 0.07344585], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/intfloat/e5-base-v2\n  def test_e5_base\n    doc_prefix = \"passage: \"\n    query_prefix = \"query: \"\n\n    input = [\n      doc_prefix + \"Ruby is a programming language created by Matz\",\n      query_prefix + \"Ruby creator\"\n    ]\n\n    model = Informers.pipeline(\"embedding\", \"intfloat/e5-base-v2\")\n    embeddings = model.(input)\n\n    assert_elements_in_delta [-0.00596662, -0.03730119, -0.0703470], embeddings[0][..2]\n    assert_elements_in_delta [0.00298353, 
-0.04421991, -0.0591884], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/nomic-ai/nomic-embed-text-v1\n  def test_nomic_embed\n    doc_prefix = \"search_document: \"\n    query_prefix = \"search_query: \"\n\n    input = [\n      doc_prefix + \"The dog is barking\",\n      query_prefix + \"puppy\"\n    ]\n\n    model = Informers.pipeline(\"embedding\", \"nomic-ai/nomic-embed-text-v1\")\n    embeddings = model.(input)\n\n    assert_elements_in_delta [-0.00645858, 0.01145126, 0.0099767], embeddings[0][..2]\n    assert_elements_in_delta [-0.01173127, 0.04957652, -0.0176401], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/BAAI/bge-base-en-v1.5\n  def test_bge_base\n    query_prefix = \"Represent this sentence for searching relevant passages: \"\n\n    input = [\n      \"The dog is barking\",\n      \"The cat is purring\",\n      query_prefix + \"puppy\"\n    ]\n\n    model = Informers.pipeline(\"embedding\", \"BAAI/bge-base-en-v1.5\")\n    embeddings = model.(input)\n\n    assert_elements_in_delta [-0.07482512, -0.0770234, 0.03398684], embeddings[1][..2]\n    assert_elements_in_delta [0.00029264, -0.0619305, -0.06199387], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/jinaai/jina-embeddings-v2-base-en\n  def test_jina_embeddings\n    sentences = [\"How is the weather today?\", \"What is the current weather like today?\"]\n\n    model = Informers.pipeline(\"embedding\", \"jinaai/jina-embeddings-v2-base-en\", model_file_name: \"../model\")\n    embeddings = model.(sentences)\n\n    assert_elements_in_delta [-0.02488641, -0.0429398, 0.04303398], embeddings[0][..2]\n    assert_elements_in_delta [-0.0081194, -0.06225249, 0.03116853], embeddings[1][..2]\n  end\n\n  # https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5\n  def test_snowflake_arctic_embed\n    query_prefix = \"Represent this sentence for searching relevant passages: \"\n\n    input = [\n      \"The dog is barking\",\n      \"The cat is purring\",\n      query_prefix + 
\"puppy\"\n    ]\n\n    model = Informers.pipeline(\"embedding\", \"Snowflake/snowflake-arctic-embed-m-v1.5\")\n    embeddings = model.(input, model_output: \"sentence_embedding\", pooling: \"none\")\n\n    assert_elements_in_delta [0.03239886, 0.0009998, 0.08401278], embeddings[0][..2]\n    assert_elements_in_delta [-0.02530634, -0.02715422, 0.01218867], embeddings[-1][..2]\n\n    embeddings = model.(input, model_output: \"token_embeddings\", pooling: \"cls\")\n\n    assert_elements_in_delta [0.03239886, 0.0009998, 0.08401278], embeddings[0][..2]\n    assert_elements_in_delta [-0.02530634, -0.02715422, 0.01218867], embeddings[-1][..2]\n  end\n\n  # https://huggingface.co/sentence-transformers/all-mpnet-base-v2\n  def test_all_mpnet\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\n    model = Informers.pipeline(\"embedding\", \"sentence-transformers/all-mpnet-base-v2\")\n    embeddings = model.(sentences)\n\n    assert_elements_in_delta [0.02250263, -0.07829167, -0.02303071], embeddings[0][..2]\n    assert_elements_in_delta [0.04170236, 0.00109747, -0.01553415], embeddings[1][..2]\n  end\n\n  # https://huggingface.co/BAAI/bge-m3\n  def test_bge_m3\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n\n    model = Informers.pipeline(\"embedding\", \"BAAI/bge-m3\")\n    model.(sentences, model_output: \"token_embeddings\")\n  end\n\n  # https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1\n  def test_mxbai_rerank\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\n    model = Informers.pipeline(\"reranking\", \"mixedbread-ai/mxbai-rerank-base-v1\")\n    result = model.(query, docs, return_documents: true)\n\n    assert_equal 0, result[0][:doc_id]\n    assert_in_delta 0.984, result[0][:score]\n    assert_equal docs[0], result[0][:text]\n\n    assert_equal 1, result[1][:doc_id]\n    
assert_in_delta 0.139, result[1][:score]\n    assert_equal docs[1], result[1][:text]\n  end\n\n  # https://huggingface.co/jinaai/jina-reranker-v1-turbo-en\n  def test_jina_reranker\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\n    model = Informers.pipeline(\"reranking\", \"jinaai/jina-reranker-v1-turbo-en\")\n    result = model.(query, docs, return_documents: true)\n\n    assert_equal 0, result[0][:doc_id]\n    assert_in_delta 0.912, result[0][:score]\n    assert_equal docs[0], result[0][:text]\n\n    assert_equal 1, result[1][:doc_id]\n    assert_in_delta 0.0555, result[1][:score]\n    assert_equal docs[1], result[1][:text]\n  end\n\n  # https://huggingface.co/BAAI/bge-reranker-base\n  def test_bge_reranker\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\n    model = Informers.pipeline(\"reranking\", \"BAAI/bge-reranker-base\")\n    result = model.(query, docs, return_documents: true)\n\n    assert_equal 0, result[0][:doc_id]\n    assert_in_delta 0.996, result[0][:score]\n    assert_equal docs[0], result[0][:text]\n\n    assert_equal 1, result[1][:doc_id]\n    assert_in_delta 0.000158, result[1][:score], 0.000001\n    assert_equal docs[1], result[1][:text]\n  end\n\n  # https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2\n  def test_ms_marco_minilm\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n\n    model = Informers.pipeline(\"reranking\", \"Xenova/ms-marco-MiniLM-L-6-v2\")\n    result = model.(query, docs, return_documents: true)\n\n    assert_equal 0, result[0][:doc_id]\n    assert_in_delta 1, result[0][:score]\n    assert_equal docs[0], result[0][:text]\n\n    assert_equal 1, result[1][:doc_id]\n    assert_in_delta 0.0067, 
result[1][:score]\n    assert_equal docs[1], result[1][:text]\n  end\nend\n"
  },
  {
    "path": "test/pipeline_test.rb",
    "content": "require_relative \"test_helper\"\n\nclass PipelineTest < Minitest::Test\n  def test_ner\n    ner = Informers.pipeline(\"ner\")\n    result = ner.(\"Ruby is a programming language created by Matz\")\n    assert_equal 1, result.size\n    assert_equal \"PER\", result[0][:entity_group]\n    assert_in_delta 0.994, result[0][:score]\n    assert_equal \"Matz\", result[0][:word]\n    assert_equal 42, result[0][:start]\n    assert_equal 46, result[0][:end]\n  end\n\n  def test_ner_aggregation_strategy\n    ner = Informers.pipeline(\"ner\")\n    result = ner.(\"Ruby is a programming language created by Matz\", aggregation_strategy: \"none\")\n    assert_equal 2, result.size\n    assert_equal \"B-PER\", result[0][:entity]\n    assert_in_delta 0.996, result[0][:score]\n    assert_equal 8, result[0][:index]\n    assert_equal \"Mat\", result[0][:word]\n    assert_equal 42, result[0][:start]\n    assert_equal 45, result[0][:end]\n  end\n\n  def test_sentiment_analysis\n    classifier = Informers.pipeline(\"sentiment-analysis\")\n    result = classifier.(\"I love transformers!\")\n    assert_equal \"POSITIVE\", result[:label]\n    assert_in_delta 0.9997887, result[:score], 0.0000001\n\n    result = classifier.(\"This is super cool\")\n    assert_equal \"POSITIVE\", result[:label]\n    assert_in_delta 0.9998608, result[:score], 0.0000001\n\n    result = classifier.([\"This is super cool\", \"I didn't like it\"])\n    assert_equal \"POSITIVE\", result[0][:label]\n    assert_in_delta 0.9998600, result[0][:score], 0.0000001\n    assert_equal \"NEGATIVE\", result[1][:label]\n    assert_in_delta 0.9985375, result[1][:score], 0.0000001\n  end\n\n  def test_question_answering\n    qa = Informers.pipeline(\"question-answering\")\n    result = qa.(\"Who invented Ruby?\", \"Ruby is a programming language created by Matz\")\n    assert_in_delta 0.998, result[:score]\n    assert_equal \"Matz\", result[:answer]\n    assert_equal 42, result[:start]\n    assert_equal 46, 
result[:end]\n  end\n\n  def test_zero_shot_classification\n    classifier = Informers.pipeline(\"zero-shot-classification\")\n    text = \"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\"\n    labels = [\"mobile\", \"billing\", \"website\", \"account access\"]\n    result = classifier.(text, labels)\n    assert_equal text, result[:sequence]\n    assert_equal [\"mobile\", \"billing\", \"account access\", \"website\"], result[:labels]\n    assert_elements_in_delta [0.633, 0.134, 0.121, 0.111], result[:scores]\n  end\n\n  def test_text2text_generation\n    text2text = Informers.pipeline(\"text2text-generation\")\n    result = text2text.(\"translate from English to French: I'm very happy\")\n    assert_equal \"Je suis très heureux.\", result[0][:generated_text]\n  end\n\n  def test_translation\n    translator = Informers.pipeline(\"translation\", \"Xenova/nllb-200-distilled-600M\")\n    result = translator.(\"जीवन एक चॉकलेट बॉक्स की तरह है।\", src_lang: \"hin_Deva\", tgt_lang: \"fra_Latn\")\n    assert_equal \"La vie est comme une boîte à chocolat.\", result[0][:translation_text]\n  end\n\n  def test_text_generation\n    generator = Informers.pipeline(\"text-generation\")\n    result = generator.(\"I enjoy walking with my cute dog,\")\n    assert_equal \"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\", result[0][:generated_text]\n  end\n\n  def test_summarization\n    skip \"TODO\"\n\n    summarizer = Informers.pipeline(\"summarization\")\n    result = summarizer.(\"Ruby is awesome.\")\n    assert_equal \"Ruby is awesome. Ruby is awesome. Ruby is great. Ruby's website is great. Ruby's site is great for the first time. Ruby will be great for all the people who want to know more about the site. Click here for more information. 
Click HERE for\", result[0][:summary_text]\n  end\n\n  def test_fill_mask\n    unmasker = Informers.pipeline(\"fill-mask\")\n    result = unmasker.(\"Paris is the [MASK] of France.\")\n    assert_equal 5, result.size\n    assert_in_delta 0.997, result[0][:score]\n    assert_equal 3007, result[0][:token]\n    assert_equal \"capital\", result[0][:token_str]\n    assert_equal \"paris is the capital of france.\", result[0][:sequence]\n  end\n\n  def test_fill_mask_no_mask_token\n    unmasker = Informers.pipeline(\"fill-mask\")\n    error = assert_raises(ArgumentError) do\n      unmasker.(\"Paris is the <mask> of France.\")\n    end\n    assert_equal \"Mask token ([MASK]) not found in text.\", error.message\n  end\n\n  def test_feature_extraction\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n    extractor = Informers.pipeline(\"feature-extraction\")\n    output = extractor.(sentences)\n    assert_in_delta (-0.0145), output[0][0][0]\n    assert_in_delta (-0.3130), output[-1][-1][-1]\n  end\n\n  def test_embedding\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n    embed = Informers.pipeline(\"embedding\")\n    embeddings = embed.(sentences)\n    assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]\n    assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]\n  end\n\n  def test_reranking\n    query = \"How many people live in London?\"\n    docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n    rerank = Informers.pipeline(\"reranking\")\n    result = rerank.(query, docs)\n    assert_equal 2, result.size\n    assert_equal 0, result[0][:doc_id]\n    assert_in_delta 0.984, result[0][:score]\n    assert_equal 1, result[1][:doc_id]\n    assert_in_delta 0.139, result[1][:score]\n  end\n\n  def test_image_classification\n    classifier = Informers.pipeline(\"image-classification\")\n    result = 
classifier.(\"test/support/pipeline-cat-chonk.jpeg\", top_k: 2)\n    assert_equal \"lynx, catamount\", result[0][:label]\n    assert_in_delta 0.428, result[0][:score], 0.01\n    assert_equal \"cougar, puma, catamount, mountain lion, painter, panther, Felis concolor\", result[1][:label]\n    assert_in_delta 0.047, result[1][:score], 0.01\n  end\n\n  def test_zero_shot_image_classification\n    classifier = Informers.pipeline(\"zero-shot-image-classification\")\n    result = classifier.(\"test/support/pipeline-cat-chonk.jpeg\", [\"dog\", \"cat\", \"tiger\"])\n    assert_equal 3, result.size\n    assert_equal \"cat\", result[0][:label]\n    assert_in_delta 0.756, result[0][:score]\n    assert_equal \"tiger\", result[1][:label]\n    assert_in_delta 0.189, result[1][:score]\n    assert_equal \"dog\", result[2][:label]\n    assert_in_delta 0.055, result[2][:score]\n  end\n\n  def test_object_detection\n    detector = Informers.pipeline(\"object-detection\")\n    result = detector.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_equal 3, result.size\n\n    assert_equal \"cat\", result[0][:label]\n    assert_in_delta 0.992, result[0][:score]\n    assert_equal 177, result[0][:box][:xmin]\n    assert_equal 153, result[0][:box][:ymin]\n    assert_equal 959, result[0][:box][:xmax]\n    assert_equal 600, result[0][:box][:ymax]\n\n    assert_equal \"bicycle\", result[2][:label]\n    assert_in_delta 0.726, result[2][:score]\n    assert_equal 0, result[2][:box][:xmin]\n    assert_equal 0, result[2][:box][:ymin]\n    assert_equal 196, result[2][:box][:xmax]\n    assert_equal 413, result[2][:box][:ymax]\n  end\n\n  def test_zero_shot_object_detection\n    detector = Informers.pipeline(\"zero-shot-object-detection\")\n    result = detector.(\"test/support/zero-sh-obj-detection_1.png\", [\"human face\", \"rocket\", \"helmet\", \"american flag\"])\n    assert_equal 4, result.size\n\n    assert_equal \"human face\", result[0][:label]\n    assert_in_delta 0.351, result[0][:score]\n 
   assert_equal 179, result[0][:box][:xmin]\n    assert_equal 72, result[0][:box][:ymin]\n    assert_equal 270, result[0][:box][:xmax]\n    assert_equal 178, result[0][:box][:ymax]\n\n    assert_equal \"rocket\", result[1][:label]\n    assert_in_delta 0.211, result[1][:score]\n    assert_equal 351, result[1][:box][:xmin]\n    assert_equal 6, result[1][:box][:ymin]\n    assert_equal 468, result[1][:box][:xmax]\n    assert_equal 289, result[1][:box][:ymax]\n  end\n\n  def test_depth_estimation\n    estimator = Informers.pipeline(\"depth-estimation\")\n    result = estimator.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_in_delta 1.078, result[:predicted_depth][0][0]\n    assert_kind_of Vips::Image, result[:depth]\n    # result[:depth].write_to_file(\"/tmp/depth-estimation.jpg\")\n  end\n\n  def test_image_to_text\n    captioner = Informers.pipeline(\"image-to-text\")\n    result = captioner.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_equal \"a cat is standing in the snow\", result[0][:generated_text]\n  end\n\n  def test_image_to_image\n    skip \"Expensive\"\n\n    upscaler = Informers.pipeline(\"image-to-image\")\n    result = upscaler.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_kind_of Vips::Image, result\n    result.write_to_file(\"/tmp/image-to-image.jpg\")\n  end\n\n  def test_image_segmentation\n    segmenter = Informers.pipeline(\"image-segmentation\")\n    result = segmenter.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_equal 3, result.size\n\n    assert_equal \"snow\", result[0][:label]\n    assert_in_delta 0.997, result[0][:score]\n    assert_equal \"LABEL_184\", result[1][:label]\n    assert_in_delta 0.993, result[1][:score]\n    assert_equal \"cat\", result[2][:label]\n    assert_in_delta 0.998, result[2][:score]\n  end\n\n  def test_image_feature_extraction\n    fe = Informers.pipeline(\"image-feature-extraction\")\n    result = fe.(\"test/support/pipeline-cat-chonk.jpeg\")\n    assert_in_delta 0.877, result[0][0], 
0.01\n  end\n\n  def test_progress_callback\n    msgs = []\n    extractor = Informers.pipeline(\"feature-extraction\", progress_callback: ->(msg) { msgs << msg })\n    extractor.(\"I love transformers!\")\n\n    expected_msgs = [\n      {status: \"initiate\", name: \"Xenova/all-MiniLM-L6-v2\", file: \"tokenizer.json\"},\n      {status: \"ready\", task: \"feature-extraction\", model: \"Xenova/all-MiniLM-L6-v2\"}\n    ]\n    expected_msgs.each do |expected|\n      assert_includes msgs, expected\n    end\n  end\n\n  def test_device\n    skip unless mac?\n\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n    embed = Informers.pipeline(\"embedding\", \"Xenova/all-MiniLM-L6-v2\", device: \"coreml\")\n    embeddings = embed.(sentences)\n    assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]\n    assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]\n  end\n\n  def test_device_invalid\n    error = assert_raises(ArgumentError) do\n      Informers.pipeline(\"embedding\", device: \"bad\")\n    end\n    assert_equal \"Unsupported device: bad. Should be one of: cpu, cuda, coreml\", error.message\n  end\n\n  def test_dtype\n    sentences = [\"This is an example sentence\", \"Each sentence is converted\"]\n    embed = Informers.pipeline(\"embedding\", \"Xenova/all-MiniLM-L6-v2\", dtype: \"fp16\")\n    embeddings = embed.(sentences)\n    assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]\n    assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]\n  end\n\n  def test_dtype_invalid\n    error = assert_raises(ArgumentError) do\n      Informers.pipeline(\"embedding\", dtype: \"bad\")\n    end\n    assert_equal \"Invalid dtype: bad. 
Should be one of: fp32, fp16, int8, uint8, q8, q4, q4f16, bnb4\", error.message\n  end\n\n  def test_session_options\n    # TODO improve test\n    Informers.pipeline(\"embedding\", session_options: {log_severity_level: 2})\n  end\nend\n"
  },
  {
    "path": "test/test_helper.rb",
    "content": "require \"bundler/setup\"\nBundler.require(:default)\nrequire \"minitest/autorun\"\n\nclass Minitest::Test\n  def assert_elements_in_delta(expected, actual, delta = 0.001)\n    assert_equal expected.size, actual.size\n    expected.zip(actual) do |exp, act|\n      assert_in_delta exp, act, delta\n    end\n  end\n\n  def mac?\n    RbConfig::CONFIG[\"host_os\"] =~ /darwin/i\n  end\nend\n"
  }
]