[
  {
    "path": ".github/workflows/main.yml",
    "content": "name: Python package\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.10\"]\n\n    steps:\n      - uses: actions/checkout@v3\n      # - uses: pre-commit/action@v3.0.0\n      #   name: Run pre-commit checks (pylint/yapf/isort)\n      #   env:\n      #     SKIP: insert-license\n      #   with:\n      #     extra_args: --hook-stage push --all-files\n      - uses: actions/setup-python@v4\n        with:\n          python-version: \"3.10\"\n          cache: \"pip\" # caching pip dependencies\n      - name: install packages\n        run: |\n          /usr/bin/python -m pip install --upgrade pip\n          pip install --no-deps -r images/requirements.txt\n          # - name: ssh access\n          #   uses: lhotari/action-upterm@v1\n          #   with:\n          #     limit-access-to-actor: true\n          #     limit-access-to-users: arashd\n      - name: run tests\n        run: |\n          # Environment variables are reset in between steps.\n          mkdir /tmp/github_testing\n          ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml\n          export PYTHONPATH=\"/tmp/github_testing:$PYTHONPATH\"\n          pytest -vv\n"
  },
  {
    "path": ".gitignore",
    "content": "# Mac\n.DS_Store\n\n# Vim\n*.py.swp\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n\n# C extensions\n*.so\n\n# Distribution / packaging\nbuild/\ndevelop-eggs/\ndist/\neggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\n.hypothesis\n\nvenv\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/pausan/cblack\n    rev: release-22.3.0\n    hooks:\n    - id: cblack\n      name: cblack\n      description: \"Black: The uncompromising Python code formatter - 2 space indent fork\"\n      entry: cblack . -l 100\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v2.3.0\n    hooks:\n    -   id: trailing-whitespace\n    -   id: end-of-file-fixer\n    -   id: check-yaml\n    -   id: check-added-large-files\n    -   id: check-merge-conflict\n"
  },
  {
    "path": "COPYING",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  
However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  
The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  
\"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. 
Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  
If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published by\n    the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  
If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>.\n"
  },
  {
    "path": "LICENSE.torchrec",
    "content": "A few files here (where it is specifically noted in comments) are based on code from torchrec but\nadapted for our use. Torchrec license is below:\n\n\nBSD 3-Clause License\n\nCopyright (c) Meta Platforms, Inc. and affiliates.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "This project open-sources some of the ML models used at Twitter.\n\nCurrently these are:\n\n1. The \"For You\" Heavy Ranker (projects/home/recap).\n\n2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387\n\n\nThis project can be run inside a Python virtualenv. We have only tried this on Linux machines, and because we use torchrec it works best with an NVIDIA GPU. To set up, run\n\n`./images/init_venv.sh` (Linux only).\n\nEach project's README contains instructions for running that project.\n"
  },
  {
    "path": "common/__init__.py",
    "content": ""
  },
  {
    "path": "common/batch.py",
    "content": "\"\"\"Extension of torchrec.dataset.utils.Batch to cover any dataset.\n\"\"\"\n# flake8: noqa\nfrom __future__ import annotations\nfrom typing import Dict\nimport abc\nfrom dataclasses import dataclass\nimport dataclasses\n\nimport torch\nfrom torchrec.streamable import Pipelineable\n\n\nclass BatchBase(Pipelineable, abc.ABC):\n  @abc.abstractmethod\n  def as_dict(self) -> Dict:\n    raise NotImplementedError\n\n  def to(self, device: torch.device, non_blocking: bool = False):\n    args = {}\n    for feature_name, feature_value in self.as_dict().items():\n      args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)\n    return self.__class__(**args)\n\n  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:\n    for feature_value in self.as_dict().values():\n      feature_value.record_stream(stream)\n\n  def pin_memory(self):\n    args = {}\n    for feature_name, feature_value in self.as_dict().items():\n      args[feature_name] = feature_value.pin_memory()\n    return self.__class__(**args)\n\n  def __repr__(self) -> str:\n    def obj2str(v):\n      return f\"{v.size()}\" if hasattr(v, \"size\") else f\"{v.length_per_key()}\"\n\n    return \"\\n\".join([f\"{k}: {obj2str(v)},\" for k, v in self.as_dict().items()])\n\n  @property\n  def batch_size(self) -> int:\n    for tensor in self.as_dict().values():\n      if tensor is None:\n        continue\n      if not isinstance(tensor, torch.Tensor):\n        continue\n      return tensor.shape[0]\n    raise Exception(\"Could not determine batch size from tensors.\")\n\n\n@dataclass\nclass DataclassBatch(BatchBase):\n  @classmethod\n  def feature_names(cls):\n    return list(cls.__dataclass_fields__.keys())\n\n  def as_dict(self):\n    return {\n      feature_name: getattr(self, feature_name)\n      for feature_name in self.feature_names()\n      if hasattr(self, feature_name)\n    }\n\n  @staticmethod\n  def from_schema(name: str, schema):\n    \"\"\"Instantiates a 
custom batch subclass if all columns can be represented as a torch.Tensor.\"\"\"\n    return dataclasses.make_dataclass(\n      cls_name=name,\n      fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],\n      bases=(DataclassBatch,),\n    )\n\n  @staticmethod\n  def from_fields(name: str, fields: dict):\n    return dataclasses.make_dataclass(\n      cls_name=name,\n      fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],\n      bases=(DataclassBatch,),\n    )\n\n\nclass DictionaryBatch(BatchBase, dict):\n  def as_dict(self) -> Dict:\n    return self\n"
  },
  {
    "path": "common/checkpointing/__init__.py",
    "content": "from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot\n"
  },
  {
    "path": "common/checkpointing/snapshot.py",
    "content": "import os\nimport time\nfrom typing import Any, Dict, List, Optional\n\nfrom tml.ml_logging.torch_logging import logging\nfrom tml.common.filesystem import infer_fs, is_gcs_fs\n\nimport torchsnapshot\n\n\nDONE_EVAL_SUBDIR = \"evaled_by\"\nGCS_PREFIX = \"gs://\"\n\n\nclass Snapshot:\n  \"\"\"Checkpoints using torchsnapshot.\n\n  Also saves step to be updated by the training loop.\n\n  \"\"\"\n\n  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:\n    self.save_dir = save_dir\n    self.state = state\n    self.state[\"extra_state\"] = torchsnapshot.StateDict(step=0, walltime=0.0)\n\n  @property\n  def step(self):\n    return self.state[\"extra_state\"][\"step\"]\n\n  @step.setter\n  def step(self, step: int) -> None:\n    self.state[\"extra_state\"][\"step\"] = step\n\n  @property\n  def walltime(self):\n    return self.state[\"extra_state\"][\"walltime\"]\n\n  @walltime.setter\n  def walltime(self, walltime: float) -> None:\n    self.state[\"extra_state\"][\"walltime\"] = walltime\n\n  def save(self, global_step: int) -> \"PendingSnapshot\":\n    \"\"\"Saves checkpoint with given global_step.\"\"\"\n    path = os.path.join(self.save_dir, str(global_step))\n    logging.info(f\"Saving snapshot global_step {global_step} to {path}.\")\n    start_time = time.time()\n    # Take a snapshot in async manner, the snapshot is consistent that state changes after this method returns have no effect on the snapshot. It performs storage I/O in the background.\n    snapshot = torchsnapshot.Snapshot.async_take(\n      app_state=self.state,\n      path=path,\n      # commented out because DistributedModelParallel model saving\n      # errors with this on multi-GPU. 
With it removed, CPU, single\n      # GPU, and multi-GPU training all successfully checkpoint.\n      # replicated=[\"**\"],\n    )\n    logging.info(f\"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s\")\n    return snapshot\n\n  def restore(self, checkpoint: str) -> None:\n    \"\"\"Restores a given checkpoint.\"\"\"\n    snapshot = torchsnapshot.Snapshot(path=checkpoint)\n    logging.info(f\"Restoring snapshot from {snapshot.path}.\")\n    start_time = time.time()\n    # We can remove the try-except when we are confident that we no longer need to restore from\n    # checkpoints from before walltime was added\n    try:\n      # checkpoints that do not have extra_state[walltime] will fail here\n      snapshot.restore(self.state)\n    except RuntimeError:\n      # extra_state[walltime] does not exist in the checkpoint, but step should be there so restore it\n      self.state[\"extra_state\"] = torchsnapshot.StateDict(step=0)\n      snapshot.restore(self.state)\n      # we still need to ensure that extra_state has walltime in it\n      self.state[\"extra_state\"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)\n\n    logging.info(f\"Restored snapshot from {snapshot.path}. 
({time.time() - start_time:.05}s\")\n\n  @classmethod\n  def get_torch_snapshot(\n    cls,\n    snapshot_path: str,\n    global_step: Optional[int] = None,\n    missing_ok: bool = False,\n  ) -> torchsnapshot.Snapshot:\n    \"\"\"Get torch stateless snapshot, without actually loading it.\n    Args:\n      snapshot_path: path to the model snapshot\n      global_step: restores from this checkpoint if specified.\n      missing_ok: if True and checkpoints do not exist, returns without restoration.\n    \"\"\"\n    path = get_checkpoint(snapshot_path, global_step, missing_ok)\n    logging.info(f\"Loading snapshot from {path}.\")\n    return torchsnapshot.Snapshot(path=path)\n\n  @classmethod\n  def load_snapshot_to_weight(\n    cls,\n    embedding_snapshot: torchsnapshot.Snapshot,\n    snapshot_emb_name: str,\n    weight_tensor,\n  ) -> None:\n    \"\"\"Loads pretrained embedding from the snapshot to the model.\n       Utilise partial lodaing meachanism from torchsnapshot.\n    Args:\n      embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).\n      snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.\n      weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.\n    \"\"\"\n    start_time = time.time()\n    manifest = embedding_snapshot.get_manifest()\n    for path in manifest.keys():\n      if path.startswith(\"0\") and snapshot_emb_name in path:\n        snapshot_path_to_load = path\n    embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)\n    logging.info(\n      f\"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s\",\n      rank=-1,\n    )\n    logging.info(f\"Snapshot loaded to {weight_tensor.metadata()}\", rank=-1)\n\n\ndef _eval_subdir(checkpoint_path: str) -> str:\n  return os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)\n\n\ndef _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:\n  return 
os.path.join(_eval_subdir(checkpoint_path), f\"{eval_partition}_DONE\")\n\n\ndef is_done_eval(checkpoint_path: str, eval_partition: str):\n  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))\n\n\ndef mark_done_eval(checkpoint_path: str, eval_partition: str):\n  infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))\n\n\ndef step_from_checkpoint(checkpoint: str) -> int:\n  return int(os.path.basename(checkpoint))\n\n\ndef checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):\n  \"\"\"Simplified equivalent of tf.train.checkpoints_iterator.\n\n  Args:\n    seconds_to_sleep: time between polling calls.\n    timeout: how long to wait for a new checkpoint.\n\n  \"\"\"\n\n  def _poll(last_checkpoint: Optional[str] = None):\n    stop_time = time.time() + timeout\n    while True:\n      _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)\n      if not _checkpoint_path or _checkpoint_path == last_checkpoint:\n        if time.time() + seconds_to_sleep > stop_time:\n          logging.info(\n            f\"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s.\"\n          )\n          return None\n        logging.info(f\"Waiting for next available checkpoint from {save_dir}.\")\n        time.sleep(seconds_to_sleep)\n      else:\n        logging.info(f\"Found latest checkpoint {_checkpoint_path}.\")\n        return _checkpoint_path\n\n  checkpoint_path = None\n  while True:\n    new_checkpoint = _poll(checkpoint_path)\n    if not new_checkpoint:\n      return\n    checkpoint_path = new_checkpoint\n    yield checkpoint_path\n\n\ndef get_checkpoint(\n  save_dir: str,\n  global_step: Optional[int] = None,\n  missing_ok: bool = False,\n) -> str:\n  \"\"\"Gets latest checkpoint or checkpoint at specified global_step.\n\n  Args:\n    global_step: Finds this checkpoint if specified.\n    missing_ok: if True and checkpoints do not exist, returns 
without restoration.\n\n  \"\"\"\n  checkpoints = get_checkpoints(save_dir)\n  if not checkpoints:\n    if not missing_ok:\n      raise Exception(f\"No checkpoints found at {save_dir}\")\n    else:\n      logging.info(f\"No checkpoints found for restoration at {save_dir}.\")\n      return \"\"\n\n  if global_step is None:\n    return checkpoints[-1]\n\n  logging.info(f\"Found checkpoints: {checkpoints}\")\n  for checkpoint in checkpoints:\n    step = step_from_checkpoint(checkpoint)\n    if global_step == step:\n      chosen_checkpoint = checkpoint\n      break\n  else:\n    raise Exception(f\"Desired checkpoint at {global_step} not found in {save_dir}\")\n  return chosen_checkpoint\n\n\ndef get_checkpoints(save_dir: str) -> List[str]:\n  \"\"\"Gets all checkpoints that have been fully written.\"\"\"\n  checkpoints = []\n  fs = infer_fs(save_dir)\n  if fs.exists(save_dir):\n    prefix = GCS_PREFIX if is_gcs_fs(fs) else \"\"\n    checkpoints = list(f\"{prefix}{elem}\" for elem in fs.ls(save_dir, detail=False))\n    # Only take checkpoints that were fully written.\n    checkpoints = list(\n      filter(\n        lambda path: fs.exists(f\"{path}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}\"),\n        checkpoints,\n      )\n    )\n    checkpoints = sorted(checkpoints, key=lambda path: int(os.path.basename(path)))\n  return checkpoints\n\n\ndef wait_for_evaluators(\n  save_dir: str,\n  partition_names: List[str],\n  global_step: int,\n  timeout: int,\n) -> None:\n  logging.info(\"Waiting for all evaluators to finish.\")\n  start_time = time.time()\n\n  for checkpoint in checkpoints_iterator(save_dir):\n    step = step_from_checkpoint(checkpoint)\n    logging.info(f\"Considering checkpoint {checkpoint} for global step {global_step}.\")\n    if step == global_step:\n      while partition_names:\n        if is_done_eval(checkpoint, partition_names[-1]):\n          logging.info(\n            f\"Checkpoint {checkpoint} marked as finished eval for partition 
{partition_names[-1]} at step {step}, still waiting for {partition_names}.\"\n          )\n          partition_names.pop()\n\n        if time.time() - start_time >= timeout:\n          logging.warning(\n            f\"Not all evaluators finished after waiting for {time.time() - start_time}\"\n          )\n          return\n        time.sleep(10)\n      logging.info(\"All evaluators finished.\")\n      return\n\n    if time.time() - start_time >= timeout:\n      logging.warning(f\"Not all evaluators finished after waiting for {time.time() - start_time}\")\n      return\n"
  },
  {
    "path": "common/device.py",
    "content": "import os\n\nimport torch\nimport torch.distributed as dist\n\n\ndef maybe_setup_tensorflow():\n  try:\n    import tensorflow as tf\n  except ImportError:\n    pass\n  else:\n    tf.config.set_visible_devices([], \"GPU\")  # disable tf gpu\n\n\ndef setup_and_get_device(tf_ok: bool = True) -> torch.device:\n  if tf_ok:\n    maybe_setup_tensorflow()\n\n  device = torch.device(\"cpu\")\n  backend = \"gloo\"\n  if torch.cuda.is_available():\n    rank = os.environ[\"LOCAL_RANK\"]\n    device = torch.device(f\"cuda:{rank}\")\n    backend = \"nccl\"\n    torch.cuda.set_device(device)\n  if not torch.distributed.is_initialized():\n    dist.init_process_group(backend)\n\n  return device\n"
  },
  {
    "path": "common/filesystem/__init__.py",
    "content": "from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs\n"
  },
  {
    "path": "common/filesystem/test_infer_fs.py",
    "content": "\"\"\"Minimal test for infer_fs.\n\nMostly a test that it returns an object\n\"\"\"\nfrom tml.common.filesystem import infer_fs\n\n\ndef test_infer_fs():\n  local_path = \"/tmp/local_path\"\n  gcs_path = \"gs://somebucket/somepath\"\n\n  local_fs = infer_fs(local_path)\n  gcs_fs = infer_fs(gcs_path)\n\n  # This should return two different objects\n  assert local_fs != gcs_fs\n"
  },
  {
    "path": "common/filesystem/util.py",
    "content": "\"\"\"Utilities for interacting with the file systems.\"\"\"\nfrom fsspec.implementations.local import LocalFileSystem\nimport gcsfs\n\n\nGCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)\nLOCAL_FS = LocalFileSystem()\n\n\ndef infer_fs(path: str):\n  if path.startswith(\"gs://\"):\n    return GCS_FS\n  elif path.startswith(\"hdfs://\"):\n    # We can probably use pyarrow HDFS to support this.\n    raise NotImplementedError(\"HDFS not yet supported\")\n  else:\n    return LOCAL_FS\n\n\ndef is_local_fs(fs):\n  return fs == LOCAL_FS\n\n\ndef is_gcs_fs(fs):\n  return fs == GCS_FS\n"
  },
  {
    "path": "common/log_weights.py",
    "content": "\"\"\"For logging model weights.\"\"\"\nimport itertools\nfrom typing import Callable, Dict, List, Optional, Union\n\nfrom tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]\nimport torch\nimport torch.distributed as dist\nfrom torchrec.distributed.model_parallel import DistributedModelParallel\n\n\ndef weights_to_log(\n  model: torch.nn.Module,\n  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,\n):\n  \"\"\"Creates dict of reduced weights to log to give sense of training.\n\n  Args:\n    model: model to traverse.\n    how_to_log: if a function, then applies this to every parameter, if a dict\n      then only applies and logs specified parameters.\n\n  \"\"\"\n  if not how_to_log:\n    return\n\n  to_log = dict()\n  named_parameters = model.named_parameters()\n  logging.info(f\"Using DMP: {isinstance(model, DistributedModelParallel)}\")\n  if isinstance(model, DistributedModelParallel):\n    named_parameters = itertools.chain(\n      named_parameters, model._dmp_wrapped_module.named_parameters()\n    )\n  logging.info(\n    f\"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}\"\n  )\n  for param_name, params in named_parameters:\n    if callable(how_to_log):\n      how = how_to_log\n    else:\n      how = how_to_log.get(param_name)  # type: ignore[assignment]\n    if not how:\n      continue  # type: ignore\n    to_log[f\"model/{how.__name__}/{param_name}\"] = how(params.detach()).cpu().numpy()\n  return to_log\n\n\ndef log_ebc_norms(\n  model_state_dict,\n  ebc_keys: List[str],\n  sample_size: int = 4_000_000,\n) -> Dict[str, torch.Tensor]:\n  \"\"\"Logs the norms of the embedding tables as specified by ebc_keys.\n  As of now, log average norm per rank.\n\n  Args:\n      model_state_dict: model.state_dict()\n      ebc_keys: list of embedding keys from state_dict to log. Must contain full name,\n      i.e. 
model.embeddings.ebc.embedding_bags.meta__user_id.weight\n      sample_size: Limits the number of rows per rank used to compute the average, to avoid OOM.\n  \"\"\"\n  norm_logs = dict()\n  for emb_key in ebc_keys:\n    norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f\"cuda:{dist.get_rank()}\"))\n    if emb_key in model_state_dict:\n      emb_weight = model_state_dict[emb_key]\n      try:\n        emb_weight_tensor = emb_weight.local_tensor()\n      except AttributeError as e:\n        logging.info(e)\n        emb_weight_tensor = emb_weight\n      logging.info(\"Running Tensor.detach()\")\n      emb_weight_tensor = emb_weight_tensor.detach()\n      sample_mask = torch.randperm(emb_weight_tensor.shape[0])[\n        : min(sample_size, emb_weight_tensor.shape[0])\n      ]\n      # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks\n      # Reduce sample_size if you observe frequent OOM errors, or remove weight logging.\n      norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)\n      logging.info(f\"Norm shape before reduction: {norms.shape}\", rank=-1)\n      norms = norms.mean().to(torch.device(f\"cuda:{dist.get_rank()}\"))\n\n    all_norms = [\n      torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())\n    ]\n    dist.all_gather(all_norms, norms)\n    for idx, norm in enumerate(all_norms):\n      if norm != -1.0:\n        norm_logs[f\"{emb_key}-norm-{idx}\"] = norm\n  logging.info(f\"Norm Logs are {norm_logs}\")\n  return norm_logs\n"
  },
  {
    "path": "common/modules/embedding/config.py",
    "content": "from typing import List\nfrom enum import Enum\n\nimport tml.core.config as base_config\nfrom tml.optimizers.config import OptimizerConfig\n\nimport pydantic\n\n\nclass DataType(str, Enum):\n  FP32 = \"fp32\"\n  FP16 = \"fp16\"\n\n\nclass EmbeddingSnapshot(base_config.BaseConfig):\n  \"\"\"Configuration for Embedding snapshot\"\"\"\n\n  emb_name: str = pydantic.Field(\n    ..., description=\"Name of the embedding table from the loaded snapshot\"\n  )\n  embedding_snapshot_uri: str = pydantic.Field(\n    ..., description=\"Path to torchsnapshot of the embedding\"\n  )\n\n\nclass EmbeddingBagConfig(base_config.BaseConfig):\n  \"\"\"Configuration for EmbeddingBag.\"\"\"\n\n  name: str = pydantic.Field(..., description=\"name of embedding bag\")\n  num_embeddings: int = pydantic.Field(..., description=\"size of embedding dictionary\")\n  embedding_dim: int = pydantic.Field(..., description=\"size of each embedding vector\")\n  pretrained: EmbeddingSnapshot = pydantic.Field(None, description=\"Snapshot properties\")\n  vocab: str = pydantic.Field(\n    None, description=\"Directory to parquet files of mapping from entity ID to table index.\"\n  )\n  # make sure to use an optimizer that matches:\n  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15\n  optimizer: OptimizerConfig\n  data_type: DataType\n\n\nclass LargeEmbeddingsConfig(base_config.BaseConfig):\n  \"\"\"Configuration for EmbeddingBagCollection.\n\n  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.\n  \"\"\"\n\n  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description=\"list of embedding tables\")\n  tables_to_log: List[str] = pydantic.Field(\n    None, description=\"list of embedding table names that we want to log during training\"\n  )\n\n\nclass Mode(str, Enum):\n  \"\"\"Job modes.\"\"\"\n\n  TRAIN = \"train\"\n  EVALUATE = \"evaluate\"\n  INFERENCE = 
\"inference\"\n"
  },
  {
    "path": "common/modules/embedding/embedding.py",
    "content": "from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType\nfrom tml.ml_logging.torch_logging import logging\n\nimport torch\nfrom torch import nn\nimport torchrec\nfrom torchrec.modules import embedding_configs\nfrom torchrec import EmbeddingBagConfig, EmbeddingBagCollection\nfrom torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor\nimport numpy as np\n\n\nclass LargeEmbeddings(nn.Module):\n  def __init__(\n    self,\n    large_embeddings_config: LargeEmbeddingsConfig,\n  ):\n    super().__init__()\n\n    tables = []\n    for table in large_embeddings_config.tables:\n      data_type = (\n        embedding_configs.DataType.FP32\n        if (table.data_type == DataType.FP32)\n        else embedding_configs.DataType.FP16\n      )\n\n      tables.append(\n        EmbeddingBagConfig(\n          embedding_dim=table.embedding_dim,\n          feature_names=[table.name],  # restricted to 1 feature per table for now\n          name=table.name,\n          num_embeddings=table.num_embeddings,\n          pooling=torchrec.PoolingType.SUM,\n          data_type=data_type,\n        )\n      )\n\n    self.ebc = EmbeddingBagCollection(\n      device=\"meta\",\n      tables=tables,\n    )\n\n    logging.info(\"********************** EBC named params are **********\")\n    logging.info(list(self.ebc.named_parameters()))\n\n    # This hook is used to perform post-processing surgery\n    # on large_embedding models to prep them for serving\n    self.surgery_cut_point = torch.nn.Identity()\n\n  def forward(\n    self,\n    sparse_features: KeyedJaggedTensor,\n  ) -> KeyedTensor:\n    pooled_embs = self.ebc(sparse_features)\n\n    # a KeyedTensor\n    return self.surgery_cut_point(pooled_embs)\n"
  },
  {
    "path": "common/run_training.py",
    "content": "import os\nimport subprocess\nimport sys\nfrom typing import Optional\n\nfrom tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]\nfrom twitter.ml.tensorflow.experimental.distributed import utils\n\nimport torch\nimport torch.distributed.run\n\n\ndef is_distributed_worker():\n  world_size = os.environ.get(\"WORLD_SIZE\", None)\n  rank = os.environ.get(\"RANK\", None)\n  return world_size is not None and rank is not None\n\n\ndef maybe_run_training(\n  train_fn,\n  module_name,\n  nproc_per_node: Optional[int] = None,\n  num_nodes: Optional[int] = None,\n  set_python_path_in_subprocess: bool = False,\n  is_chief: Optional[bool] = False,\n  **training_kwargs,\n):\n  \"\"\"Wrapper function for single node, multi-GPU Pytorch training.\n\n  If the necessary distributed Pytorch environment variables\n  (WORLD_SIZE, RANK) have been set, then this function executes\n  `train_fn(**training_kwargs)`.\n\n  Otherwise, this function calls torchrun and points at the calling module\n  `module_name`.  
After this call, the necessary environment variables are set\n  and training will commence.\n\n  Args:\n    train_fn:  The function that is responsible for training\n    module_name:  The name of the module that this function was called from;\n       used to indicate torchrun entrypoint.\n    nproc_per_node: Number of workers per node; supported values.\n    num_nodes: Number of nodes, otherwise inferred from environment.\n    is_chief: If process is running on chief.\n    set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.\n  \"\"\"\n\n  machines = utils.machine_from_env()\n  if num_nodes is None:\n    num_nodes = 1\n    if machines.num_workers:\n      num_nodes += machines.num_workers\n\n  if is_distributed_worker():\n    # world_size, rank, etc are set; assuming any other env vars are set (checks to come)\n    # start the actual training!\n    train_fn(**training_kwargs)\n  else:\n    if nproc_per_node is None:\n      if torch.cuda.is_available():\n        nproc_per_node = torch.cuda.device_count()\n      else:\n        nproc_per_node = machines.chief.num_accelerators\n\n    # Rejoin all arguments to send back through torchrec\n    # this is a temporary measure, will replace the os.system call\n    # with torchrun API calls\n    args = list(f\"--{key}={val}\" for key, val in training_kwargs.items())\n\n    cmd = [\n      \"--nnodes\",\n      str(num_nodes),\n    ]\n    if nproc_per_node:\n      cmd.extend([\"--nproc_per_node\", str(nproc_per_node)])\n    if num_nodes > 1:\n      cluster_resolver = utils.cluster_resolver()\n      backend_address = cluster_resolver.cluster_spec().task_address(\"chief\", 0)\n      cmd.extend(\n        [\n          \"--rdzv_backend\",\n          \"c10d\",\n          \"--rdzv_id\",\n          backend_address,\n        ]\n      )\n      # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388\n      if is_chief:\n        cmd.extend([\"--rdzv_endpoint\", \"localhost:2222\"])\n      
else:\n        cmd.extend([\"--rdzv_endpoint\", backend_address])\n    else:\n      cmd.append(\"--standalone\")\n\n    cmd.extend(\n      [\n        str(module_name),\n        *args,\n      ]\n    )\n    logging.info(f\"\"\"Distributed running with cmd: '{\" \".join(cmd)}'\"\"\")\n\n    # Call torchrun on this module;  will spawn new processes and re-run this\n    # function, eventually calling \"train_fn\". The following line sets the PYTHONPATH to accommodate\n    # bazel stubbing for the main binary.\n    if set_python_path_in_subprocess:\n      subprocess.run([\"torchrun\"] + cmd, env={**os.environ, \"PYTHONPATH\": \":\".join(sys.path)})\n    else:\n      torch.distributed.run.main(cmd)\n"
  },
  {
    "path": "common/test_device.py",
    "content": "\"\"\"Minimal test for device.\n\nMostly a test that this can be imported properly even tho moved.\n\"\"\"\nfrom unittest.mock import patch\n\nimport tml.common.device as device_utils\n\n\ndef test_device():\n  with patch(\"tml.common.device.dist.init_process_group\"):\n    device = device_utils.setup_and_get_device(tf_ok=False)\n  assert device.type == \"cpu\"\n"
  },
  {
    "path": "common/testing_utils.py",
    "content": "from contextlib import contextmanager\nimport datetime\nimport os\nfrom unittest.mock import patch\n\nimport torch.distributed as dist\nfrom tml.ml_logging.torch_logging import logging\n\n\nMOCK_ENV = {\n  \"LOCAL_RANK\": \"0\",\n  \"WORLD_SIZE\": \"1\",\n  \"LOCAL_WORLD_SIZE\": \"1\",\n  \"MASTER_ADDR\": \"localhost\",\n  \"MASTER_PORT\": \"29501\",\n  \"RANK\": \"0\",\n}\n\n\n@contextmanager\ndef mock_pg():\n  with patch.dict(os.environ, MOCK_ENV):\n    try:\n      dist.init_process_group(\n        backend=\"gloo\",\n        timeout=datetime.timedelta(1),\n      )\n      yield\n    except:\n      dist.destroy_process_group()\n      raise\n    finally:\n      dist.destroy_process_group()\n"
  },
  {
    "path": "common/utils.py",
    "content": "import yaml\nimport getpass\nimport os\nimport string\nfrom typing import Tuple, Type, TypeVar\n\nfrom tml.core.config import base_config\n\nimport fsspec\n\nC = TypeVar(\"C\", bound=base_config.BaseConfig)\n\n\ndef _read_file(f):\n  with fsspec.open(f) as f:\n    return f.read()\n\n\ndef setup_configuration(\n  config_type: Type[C],\n  yaml_path: str,\n  substitute_env_variable: bool = False,\n) -> Tuple[C, str]:\n  \"\"\"Resolves a config at a yaml path.\n\n  Args:\n    config_type: Pydantic config class to load.\n    yaml_path: yaml path of the config file.\n    substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their\n    environment variable value whenever possible. If an environment variable doesn't exist,\n    the string is left unchanged.\n\n  Returns:\n    The pydantic config object.\n  \"\"\"\n\n  def _substitute(s):\n    if substitute_env_variable:\n      return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())\n    return s\n\n  assert config_type is not None, \"can't use all_config without config_type\"\n  content = _substitute(yaml.safe_load(_read_file(yaml_path)))\n  return config_type.parse_obj(content)\n"
  },
  {
    "path": "common/wandb.py",
    "content": "from typing import Any, Dict, List\n\nimport tml.core.config as base_config\n\nimport pydantic\n\n\nclass WandbConfig(base_config.BaseConfig):\n  host: str = pydantic.Field(\n    \"https://https--wandb--prod--wandb.service.qus1.twitter.biz/\",\n    description=\"Host of Weights and Biases instance, passed to login.\",\n  )\n  key_path: str = pydantic.Field(description=\"Path to key file.\")\n\n  name: str = pydantic.Field(None, description=\"Name of the experiment, passed to init.\")\n  entity: str = pydantic.Field(None, description=\"Name of user/service account, passed to init.\")\n  project: str = pydantic.Field(None, description=\"Name of wandb project, passed to init.\")\n  tags: List[str] = pydantic.Field([], description=\"List of tags, passed to init.\")\n  notes: str = pydantic.Field(None, description=\"Notes, passed to init.\")\n  metadata: Dict[str, Any] = pydantic.Field(None, description=\"Additional metadata to log.\")\n"
  },
  {
    "path": "core/__init__.py",
    "content": ""
  },
  {
    "path": "core/config/__init__.py",
    "content": "from tml.core.config.base_config import BaseConfig\nfrom tml.core.config.config_load import load_config_from_yaml\n\n# Make mypy happy by explicitly rexporting the symbols intended for end user use.\n__all__ = [\"BaseConfig\", \"load_config_from_yaml\"]\n"
  },
  {
    "path": "core/config/base_config.py",
    "content": "\"\"\"Base class for all config (forbids extra fields).\"\"\"\n\nimport collections\nimport functools\nimport yaml\n\nimport pydantic\n\n\nclass BaseConfig(pydantic.BaseModel):\n  \"\"\"Base class for all derived config classes.\n\n  This class provides some convenient functionality:\n    - Disallows extra fields when constructing an object. User error\n      should be reduced by exact arguments.\n    - \"one_of\" fields. A subclass can group optional fields and enforce\n      that only one of the fields be set. For example:\n\n      ```\n      class ExampleConfig(BaseConfig):\n        x: int = Field(None, one_of=\"group_1\")\n        y: int = Field(None, one_of=\"group_1\")\n\n      ExampleConfig(x=1) # ok\n      ExampleConfig(y=1) # ok\n      ExampleConfig(x=1, y=1) # throws error\n      ```\n  \"\"\"\n\n  class Config:\n    \"\"\"Forbids extras.\"\"\"\n\n    extra = pydantic.Extra.forbid  # noqa\n\n  @classmethod\n  @functools.lru_cache()\n  def _field_data_map(cls, field_data_name):\n    \"\"\"Create a map of fields with provided the field data.\"\"\"\n    schema = cls.schema()\n    one_of = collections.defaultdict(list)\n    for field, fdata in schema[\"properties\"].items():\n      if field_data_name in fdata:\n        one_of[fdata[field_data_name]].append(field)\n    return one_of\n\n  @pydantic.root_validator\n  def _one_of_check(cls, values):\n    \"\"\"Validate that all 'one of' fields are appear exactly once.\"\"\"\n    one_of_map = cls._field_data_map(\"one_of\")\n    for one_of, field_names in one_of_map.items():\n      if sum([values.get(n, None) is not None for n in field_names]) != 1:\n        raise ValueError(f\"Exactly one of {','.join(field_names)} required.\")\n    return values\n\n  @pydantic.root_validator\n  def _at_most_one_of_check(cls, values):\n    \"\"\"Validate that all 'at_most_one_of' fields appear at most once.\"\"\"\n    at_most_one_of_map = cls._field_data_map(\"at_most_one_of\")\n    for one_of, field_names in 
at_most_one_of_map.items():\n      if sum([values.get(n, None) is not None for n in field_names]) > 1:\n        raise ValueError(f\"At most one of {','.join(field_names)} can be set.\")\n    return values\n\n  def pretty_print(self) -> str:\n    \"\"\"Return a human legible (yaml) representation of the config useful for logging.\"\"\"\n    return yaml.dump(self.dict())\n"
  },
  {
    "path": "core/config/base_config_test.py",
    "content": "from unittest import TestCase\n\nfrom tml.core.config import BaseConfig\n\nimport pydantic\n\n\nclass BaseConfigTest(TestCase):\n  def test_extra_forbidden(self):\n    class Config(BaseConfig):\n      x: int\n\n    Config(x=1)\n    with self.assertRaises(pydantic.ValidationError):\n      Config(x=1, y=2)\n\n  def test_one_of(self):\n    class Config(BaseConfig):\n      x: int = pydantic.Field(None, one_of=\"f\")\n      y: int = pydantic.Field(None, one_of=\"f\")\n\n    with self.assertRaises(pydantic.ValidationError):\n      Config()\n    Config(x=1)\n    Config(y=1)\n    with self.assertRaises(pydantic.ValidationError):\n      Config(x=1, y=3)\n\n  def test_at_most_one_of(self):\n    class Config(BaseConfig):\n      x: int = pydantic.Field(None, at_most_one_of=\"f\")\n      y: str = pydantic.Field(None, at_most_one_of=\"f\")\n\n    Config()\n    Config(x=1)\n    Config(y=\"a\")\n    with self.assertRaises(pydantic.ValidationError):\n      Config(x=1, y=\"a\")\n"
  },
  {
    "path": "core/config/config_load.py",
    "content": "import yaml\nimport string\nimport getpass\nimport os\nfrom typing import Type\n\nfrom tml.core.config.base_config import BaseConfig\n\n\ndef load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):\n  \"\"\"Recommend method to load a config file (a yaml file) and parse it.\n\n  Because we have a shared filesystem the recommended route to running jobs it put modified config\n  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.\n  \"\"\"\n\n  def _substitute(s):\n    return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())\n\n  with open(yaml_path, \"r\") as f:\n    raw_contents = f.read()\n    obj = yaml.safe_load(_substitute(raw_contents))\n\n  return config_type.parse_obj(obj)\n"
  },
  {
    "path": "core/config/test_config_load.py",
    "content": "from unittest import TestCase\n\nfrom tml.core.config import BaseConfig, load_config_from_yaml\n\nimport pydantic\nimport getpass\nimport pydantic\n\n\nclass _PointlessConfig(BaseConfig):\n  a: int\n  user: str\n\n\ndef test_load_config_from_yaml(tmp_path):\n  yaml_path = tmp_path.joinpath(\"test.yaml\").as_posix()\n  with open(yaml_path, \"w\") as yaml_file:\n    yaml_file.write(\"\"\"a: 3\\nuser: ${USER}\\n\"\"\")\n\n  pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)\n\n  assert pointless_config.a == 3\n  assert pointless_config.user == getpass.getuser()\n"
  },
  {
    "path": "core/config/training.py",
    "content": "from typing import Any, Dict, List, Optional\n\nfrom tml.common.wandb import WandbConfig\nfrom tml.core.config import base_config\nfrom tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.projects.twhin.models.config import TwhinModelConfig\n\nimport pydantic\n\n\nclass RuntimeConfig(base_config.BaseConfig):\n  wandb: WandbConfig = pydantic.Field(None)\n  enable_tensorfloat32: bool = pydantic.Field(\n    False, description=\"Use tensorfloat32 if on Ampere devices.\"\n  )\n  enable_amp: bool = pydantic.Field(False, description=\"Enable automatic mixed precision.\")\n\n\nclass TrainingConfig(base_config.BaseConfig):\n  save_dir: str = pydantic.Field(\"/tmp/model\", description=\"Directory to save checkpoints.\")\n  num_train_steps: pydantic.PositiveInt = 10000\n  initial_checkpoint_dir: str = pydantic.Field(\n    None, description=\"Directory of initial checkpoints\", at_most_one_of=\"initialization\"\n  )\n  checkpoint_every_n: pydantic.PositiveInt = 1000\n  checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Maximum number of checkpoints to keep. Defaults to keeping all.\"\n  )\n  train_log_every_n: pydantic.PositiveInt = 1000\n  num_eval_steps: int = pydantic.Field(\n    16384, description=\"Number of evaluation steps. If < 0 the entire dataset will be used.\"\n  )\n  eval_log_every_n: pydantic.PositiveInt = 5000\n\n  eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60\n\n  gradient_accumulation: int = pydantic.Field(\n    None, description=\"Number of replica steps to accumulate gradients.\"\n  )\n  num_epochs: pydantic.PositiveInt = 1\n"
  },
  {
    "path": "core/custom_training_loop.py",
    "content": "\"\"\"Torch and torchrec specific training and evaluation loops.\n\nFeatures (go/100_enablements):\n    - CUDA data-fetch, compute, gradient-push overlap\n    - Large learnable embeddings through torchrec\n    - On/off-chief evaluation\n    - Warmstart/checkpoint management\n    - go/dataset-service 0-copy integration\n\n\"\"\"\nimport datetime\nimport os\nfrom typing import Callable, Dict, Iterable, List, Mapping, Optional\n\n\nfrom tml.common import log_weights\nimport tml.common.checkpointing.snapshot as snapshot_lib\nfrom tml.core.losses import get_global_loss_detached\nfrom tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]\nfrom tml.core.train_pipeline import TrainPipelineSparseDist\n\nimport tree\nimport torch\nimport torch.distributed as dist\nfrom torch.optim.lr_scheduler import _LRScheduler\nimport torchmetrics as tm\n\n\ndef get_new_iterator(iterable: Iterable):\n  \"\"\"\n  This obtain a new iterator from the iterable. If the iterable uses tf.data.Dataset internally,\n   getting a new iterator each N steps will avoid memory leak. 
To avoid the memory leak\n   calling iter(iterable) should return a \"fresh\" iterator using a fresh\n   (new instance of) tf.data.Iterator.\n   In particular, iterable can be a torch.utils.data.IterableDataset or a\n   torch.utils.data.DataLoader.\n\n  When using DDS, performing this reset does not change the order in which elements are received\n   (excluding elements already prefetched) provided that iter(iterable) internally uses\n   a new instance of tf.data.Dataset created by calling from_dataset_id.\n   This requirement is satisfied by RecapDataset.\n  :param iterable:\n  :return:\n  \"\"\"\n  return iter(iterable)\n\n\ndef _get_step_fn(pipeline, data_iterator, training: bool):\n  def step_fn():\n    # It turns out that model.train() and model.eval() simply switch a single field inside the model\n    # class,so it's somewhat safer to wrap in here.\n    if training:\n      pipeline._model.train()\n    else:\n      pipeline._model.eval()\n\n    outputs = pipeline.progress(data_iterator)\n    return tree.map_structure(lambda elem: elem.detach(), outputs)\n\n  return step_fn\n\n\n@torch.no_grad()\ndef _run_evaluation(\n  pipeline,\n  dataset,\n  eval_steps: int,\n  metrics: tm.MetricCollection,\n  eval_batch_size: int,\n  logger=None,\n):\n  \"\"\"Runs the evaluation loop over all evaluation iterators.\"\"\"\n  dataset = get_new_iterator(dataset)\n  step_fn = _get_step_fn(pipeline, dataset, training=False)\n  last_time = datetime.datetime.now()\n  logging.info(f\"Starting {eval_steps} steps of evaluation.\")\n  for _ in range(eval_steps):\n    outputs = step_fn()\n    metrics.update(outputs)\n  eval_ex_per_s = (\n    eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()\n  )\n  logging.info(f\"eval examples_per_s : {eval_ex_per_s}\")\n  metrics_result = metrics.compute()\n  # Resetting at end to release metrics memory not in use.\n  # Reset metrics to prevent accumulation between multiple evaluation splits and not report a\n  # 
running average.\n  metrics.reset()\n  return metrics_result\n\n\ndef train(\n  model: torch.nn.Module,\n  optimizer: torch.optim.Optimizer,\n  device: str,\n  save_dir: str,\n  logging_interval: int,\n  train_steps: int,\n  checkpoint_frequency: int,\n  dataset: Iterable,\n  worker_batch_size: int,\n  num_workers: Optional[int] = 0,\n  enable_amp: bool = False,\n  initial_checkpoint_dir: Optional[str] = None,\n  gradient_accumulation: Optional[int] = None,\n  logger_initializer: Optional[Callable] = None,\n  scheduler: _LRScheduler = None,\n  metrics: Optional[tm.MetricCollection] = None,\n  parameters_to_log: Optional[Dict[str, Callable]] = None,\n  tables_to_log: Optional[List[str]] = None,\n) -> None:\n  \"\"\"Runs training and eval on the given TrainPipeline\n\n  Args:\n    dataset: data iterator for the training set\n    evaluation_iterators: data iterators for the different evaluation sets\n    scheduler: optional learning rate scheduler\n    output_transform_for_metrics: optional transformation functions to transorm the model\n                                  output and labels into a format the metrics can understand\n  \"\"\"\n\n  train_pipeline = TrainPipelineSparseDist(\n    model=model,\n    optimizer=optimizer,\n    device=device,\n    enable_amp=enable_amp,\n    grad_accum=gradient_accumulation,\n  )  # type: ignore[var-annotated]\n\n  # We explicitly initialize optimizer state here so that checkpoint will work properly.\n  if hasattr(train_pipeline._optimizer, \"init_state\"):\n    train_pipeline._optimizer.init_state()\n\n  save_state = {\n    \"model\": train_pipeline._model,\n    \"optimizer\": train_pipeline._optimizer,\n    \"scaler\": train_pipeline._grad_scaler,\n  }\n\n  chosen_checkpoint = None\n  checkpoint_handler = snapshot_lib.Snapshot(\n    save_dir=save_dir,\n    state=save_state,\n  )\n\n  if save_dir:\n    chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)\n\n  start_step = 0\n  start_walltime = 
0.0\n  if chosen_checkpoint:\n    # Skip restoration and exit if we should be finished.\n    chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)\n    if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:\n      logging.info(\n        \"Not restoring and finishing training as latest checkpoint \"\n        f\"{chosen_checkpoint} found \"\n        f\"at global_step ({chosen_checkpoint_global_step}) >= \"\n        f\"train_steps ({dist.get_world_size() * train_steps})\"\n      )\n      return\n    logging.info(f\"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}\")\n    checkpoint_handler.restore(chosen_checkpoint)\n    start_step = checkpoint_handler.step\n    start_walltime = checkpoint_handler.walltime\n  elif initial_checkpoint_dir:\n    base, ckpt_step = os.path.split(initial_checkpoint_dir)\n    warmstart_handler = snapshot_lib.Snapshot(\n      save_dir=base,\n      state=save_state,\n    )\n    ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))\n    logging.info(\n      f\"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0.\"\n    )\n    warmstart_handler.restore(ckpt)\n\n  train_logger = logger_initializer(mode=\"train\") if logger_initializer else None\n  train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)\n\n  # Counting number of parameters in the model directly when creating it.\n  nb_param = 0\n  for p in model.parameters():\n    nb_param += p.numel()\n  logging.info(f\"Model has {nb_param} parameters\")\n\n  last_time = datetime.datetime.now()\n  start_time = last_time\n  last_pending_snapshot = None\n  for step in range(start_step, train_steps + 1):\n    checkpoint_handler.step = step\n    outputs = train_step_fn()\n    step_done_time = datetime.datetime.now()\n    checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + 
start_walltime\n\n    if scheduler:\n      scheduler.step()\n\n    if step % logging_interval == 0:\n      interval_time = (step_done_time - last_time).total_seconds()\n      steps_per_s = logging_interval / interval_time\n      worker_example_per_s = steps_per_s * worker_batch_size\n      global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))\n      global_step = step\n\n      log_values = {\n        \"global_step\": global_step,\n        \"loss\": get_global_loss_detached(outputs[\"loss\"]),\n        \"steps_per_s\": steps_per_s,\n        \"global_example_per_s\": global_example_per_s,\n        \"worker_examples_per_s\": worker_example_per_s,\n        \"active_training_walltime\": checkpoint_handler.walltime,\n      }\n      if parameters_to_log:\n        log_values.update(\n          log_weights.weights_to_log(\n            model=model,\n            how_to_log=parameters_to_log,\n          )\n        )\n      log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)\n\n      if tables_to_log:\n        log_values.update(\n          log_weights.log_ebc_norms(\n            model_state_dict=train_pipeline._model.state_dict(),\n            ebc_keys=tables_to_log,\n          )\n        )\n      if train_logger:\n        train_logger.log(log_values, step=global_step)\n      log_line = \", \".join(f\"{name}: {value}\" for name, value in log_values.items())\n      logging.info(f\"Step: {step}, training. {log_line}\")\n      last_time = step_done_time\n\n      # If we just restored, do not save again.\n      if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:\n        if last_pending_snapshot and not last_pending_snapshot.done():\n          logging.warning(\n            \"Begin a new snapshot and the last one hasn't finished. That probably indicates \"\n            \"either you're snapshotting really often or something is wrong. 
Will now block and \"\n            \"wait for snapshot to finish before beginning the next one.\"\n          )\n          last_pending_snapshot.wait()\n        last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())\n\n  # Save if we did not just save.\n  if checkpoint_frequency and step % checkpoint_frequency != 0:\n    # For the final save, wait for the checkpoint to write to make sure the process doesn't finish\n    # before its completed.\n    last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())\n  logging.info(f\"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}\")\n\n  if last_pending_snapshot:\n    logging.info(f\"Waiting for any checkpoints to finish.\")\n    last_pending_snapshot.wait()\n\n\ndef log_eval_results(\n  results,\n  eval_logger,\n  partition_name: str,\n  step: int,\n):\n  results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)\n  logging.info(f\"Step: {step}, evaluation ({partition_name}).\")\n  for metric_name, metric_value in results.items():\n    logging.info(f\"\\t{metric_name}: {metric_value:1.4e}\")\n\n  if eval_logger:\n    eval_logger.log(results, step=step, commit=True)\n\n\ndef only_evaluate(\n  model: torch.nn.Module,\n  optimizer: torch.optim.Optimizer,\n  device: str,\n  save_dir: str,\n  num_train_steps: int,\n  dataset: Iterable,\n  eval_batch_size: int,\n  num_eval_steps: int,\n  eval_timeout_in_s: int,\n  eval_logger: Callable,\n  partition_name: str,\n  metrics: Optional[tm.MetricCollection] = None,\n):\n  logging.info(f\"Evaluating on partition {partition_name}.\")\n  logging.info(\"Computing metrics:\")\n  logging.info(metrics)\n  eval_pipeline = TrainPipelineSparseDist(model, optimizer, device)  # type: ignore[var-annotated]\n  save_state = {\n    \"model\": eval_pipeline._model,\n    \"optimizer\": eval_pipeline._optimizer,\n  }\n  checkpoint_handler = snapshot_lib.Snapshot(\n    save_dir=save_dir,\n  
  state=save_state,\n  )\n  for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):\n    checkpoint_handler.restore(checkpoint_path)\n    step = checkpoint_handler.step\n    dataset = get_new_iterator(dataset)\n    results = _run_evaluation(\n      pipeline=eval_pipeline,\n      dataset=dataset,\n      eval_steps=num_eval_steps,\n      eval_batch_size=eval_batch_size,\n      metrics=metrics,\n    )\n    log_eval_results(results, eval_logger, partition_name, step=step)\n    rank = dist.get_rank() if dist.is_initialized() else 0\n    if rank == 0:\n      snapshot_lib.mark_done_eval(checkpoint_path, partition_name)\n    if step >= num_train_steps:\n      return\n"
  },
  {
    "path": "core/debug_training_loop.py",
    "content": "\"\"\"This is a very limited feature training loop useful for interactive debugging.\n\nIt is not intended for actual model tranining (it is not fast, doesn't compile the model).\nIt does not support checkpointing.\n\nsuggested use:\n\nfrom tml.core import debug_training_loop\ndebug_training_loop.train(...)\n\"\"\"\n\nfrom typing import Iterable, Optional, Dict, Callable, List\nimport torch\nfrom torch.optim.lr_scheduler import _LRScheduler\nimport torchmetrics as tm\n\nfrom tml.ml_logging.torch_logging import logging\n\n\ndef train(\n  model: torch.nn.Module,\n  optimizer: torch.optim.Optimizer,\n  train_steps: int,\n  dataset: Iterable,\n  scheduler: _LRScheduler = None,\n  # Accept any arguments (to be compatible with the real training loop)\n  # but just ignore them.\n  *args,\n  **kwargs,\n) -> None:\n\n  logging.warning(\"Running debug training loop, don't use for model training.\")\n\n  data_iter = iter(dataset)\n  for step in range(0, train_steps + 1):\n    x = next(data_iter)\n    optimizer.zero_grad()\n    loss, outputs = model.forward(x)\n    loss.backward()\n    optimizer.step()\n\n    if scheduler:\n      scheduler.step()\n\n    logging.info(f\"Step {step} completed. Loss = {loss}\")\n"
  },
  {
    "path": "core/loss_type.py",
    "content": "\"\"\"Loss type enums.\"\"\"\nfrom enum import Enum\n\n\nclass LossType(str, Enum):\n  CROSS_ENTROPY = \"cross_entropy\"\n  BCE_WITH_LOGITS = \"bce_with_logits\"\n"
  },
  {
    "path": "core/losses.py",
    "content": "\"\"\"Loss functions -- including multi task ones.\"\"\"\n\nimport typing\n\nfrom tml.core.loss_type import LossType\nfrom tml.ml_logging.torch_logging import logging\n\nimport torch\n\n\ndef _maybe_warn(reduction: str):\n  \"\"\"\n  Warning for reduction different than mean.\n  \"\"\"\n  if reduction != \"mean\":\n    logging.warn(\n      f\"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,\"\n      f\"to the gradient without DDP only for mean reduction. If you need this property for\"\n      f\"the provided reduction {reduction}, it needs to be implemented.\"\n    )\n\n\ndef build_loss(\n  loss_type: LossType,\n  reduction=\"mean\",\n):\n  _maybe_warn(reduction)\n  f = _LOSS_TYPE_TO_FUNCTION[loss_type]\n\n  def loss_fn(logits, labels):\n    return f(logits, labels.type_as(logits), reduction=reduction)\n\n  return loss_fn\n\n\ndef get_global_loss_detached(local_loss, reduction=\"mean\"):\n  \"\"\"\n  Perform all_reduce to obtain the global loss function using the provided reduction.\n  :param local_loss: The local loss of the current rank.\n  :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.\n  :return: The reduced & detached global loss.\n  \"\"\"\n  if reduction != \"mean\":\n    logging.warn(\n      f\"The reduction used in this function should be the same as the one used by \"\n      f\"the DDP model. 
By default DDP uses mean, So ensure that DDP is appropriately\"\n      f\"modified for reduction {reduction}.\"\n    )\n\n  if reduction not in [\"mean\", \"sum\"]:\n    raise ValueError(f\"Reduction {reduction} is currently unsupported.\")\n\n  global_loss = local_loss.detach()\n\n  if reduction == \"mean\":\n    global_loss.div_(torch.distributed.get_world_size())\n\n  torch.distributed.all_reduce(global_loss)\n  return global_loss\n\n\ndef build_multi_task_loss(\n  loss_type: LossType,\n  tasks: typing.List[str],\n  task_loss_reduction=\"mean\",\n  global_reduction=\"mean\",\n  pos_weights=None,\n):\n  _maybe_warn(global_reduction)\n  _maybe_warn(task_loss_reduction)\n  f = _LOSS_TYPE_TO_FUNCTION[loss_type]\n\n  loss_reduction_fns = {\n    \"mean\": torch.mean,\n    \"sum\": torch.sum,\n    \"min\": torch.min,\n    \"max\": torch.max,\n    \"median\": torch.median,\n  }\n\n  def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):\n    if pos_weights is None:\n      torch_weights = torch.ones([len(tasks)])\n    else:\n      torch_weights = torch.tensor(pos_weights)\n\n    losses = {}\n    for task_idx, task in enumerate(tasks):\n      task_logits = logits[:, task_idx]\n      label = labels[:, task_idx].type_as(task_logits)\n\n      loss = f(\n        task_logits,\n        label,\n        reduction=task_loss_reduction,\n        pos_weight=torch_weights[task_idx],\n        weight=weights[:, task_idx],\n      )\n      losses[f\"loss/{task}\"] = loss\n\n    losses[\"loss\"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))\n    return losses\n\n  return loss_fn\n\n\n_LOSS_TYPE_TO_FUNCTION = {\n  LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits\n}\n"
  },
  {
    "path": "core/metric_mixin.py",
    "content": "\"\"\"\nMixin that requires a transform to munge output dictionary of tensors a\nmodel produces to a form that the torchmetrics.Metric.update expects.\n\nBy unifying on our signature for `update`, we can also now use\ntorchmetrics.MetricCollection which requires all metrics have\nthe same call signature.\n\nTo use, override this with a transform that munges `outputs`\ninto a kwargs dict that the inherited metric.update accepts.\n\nHere are two examples of how to extend torchmetrics.SumMetric so that it accepts\nan output dictionary of tensors and munges it to what SumMetric expects (single `value`)\nfor its update method.\n\n1. Using as a mixin to inherit from or define a new metric class.\n\n  class Count(MetricMixin, SumMetric):\n    def transform(self, outputs):\n      return {'value': 1}\n\n2. Redefine an existing metric class.\n\n  SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})\n\n\"\"\"\nfrom abc import abstractmethod\nfrom typing import Callable, Dict, List\n\nfrom tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]\n\nimport torch\nimport torchmetrics\n\n\nclass MetricMixin:\n  @abstractmethod\n  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:\n    ...\n\n  def update(self, outputs: Dict[str, torch.Tensor]):\n    results = self.transform(outputs)\n    # Do not try to update if any tensor is empty as a result of stratification.\n    for value in results.values():\n      if torch.is_tensor(value) and not value.nelement():\n        return\n    super().update(**results)\n\n\nclass TaskMixin:\n  def __init__(self, task_idx: int = -1, **kwargs):\n    super().__init__(**kwargs)\n    self._task_idx = task_idx\n\n\nclass StratifyMixin:\n  def __init__(\n    self,\n    stratifier=None,\n    **kwargs,\n  ):\n    super().__init__(**kwargs)\n    self._stratifier = stratifier\n\n  def maybe_apply_stratification(\n    self, outputs: Dict[str, torch.Tensor], value_names: List[str]\n  ) -> 
Dict[str, torch.Tensor]:\n    \"\"\"Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.\"\"\"\n    outputs = outputs.copy()\n    if not self._stratifier:\n      return outputs\n    stratifiers = outputs.get(\"stratifiers\")\n    if not stratifiers:\n      return outputs\n    if stratifiers.get(self._stratifier.name) is None:\n      return outputs\n\n    mask = torch.flatten(outputs[\"stratifiers\"][self._stratifier.name] == self._stratifier.value)\n    target_slice = torch.squeeze(mask.nonzero(), -1)\n    for value_name in value_names:\n      target = outputs[value_name]\n      outputs[value_name] = torch.index_select(target, 0, target_slice)\n    return outputs\n\n\ndef prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):\n  \"\"\"Returns new class using MetricMixin and given base_metric.\n\n  Functionally the same using inheritance, just saves some lines of code\n  if no need for class attributes.\n\n  \"\"\"\n\n  def transform_method(_self, *args, **kwargs):\n    return transform(*args, **kwargs)\n\n  return type(\n    base_metric.__name__,\n    (\n      MetricMixin,\n      base_metric,\n    ),\n    {\"transform\": transform_method},\n  )\n"
  },
  {
    "path": "core/metrics.py",
    "content": "\"\"\"Common metrics that also support multi task.\n\nWe assume multi task models will output [task_idx, ...] predictions\n\n\"\"\"\nfrom typing import Any, Dict\n\nfrom tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin\n\nimport torch\nimport torchmetrics as tm\n\n\ndef probs_and_labels(\n  outputs: Dict[str, torch.Tensor],\n  task_idx: int,\n) -> Dict[str, torch.Tensor]:\n  preds = outputs[\"probabilities\"]\n  target = outputs[\"labels\"]\n  if task_idx >= 0:\n    preds = preds[:, task_idx]\n    target = target[:, task_idx]\n  return {\n    \"preds\": preds,\n    \"target\": target.int(),\n  }\n\n\nclass Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"labels\"])\n    value = outputs[\"labels\"]\n    if self._task_idx >= 0:\n      value = value[:, self._task_idx]\n    return {\"value\": value}\n\n\nclass Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"labels\"])\n    value = outputs[\"labels\"]\n    if self._task_idx >= 0:\n      value = value[:, self._task_idx]\n    return {\"value\": value}\n\n\nclass Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"probabilities\"])\n    value = outputs[\"probabilities\"]\n    if self._task_idx >= 0:\n      value = value[:, self._task_idx]\n    return {\"value\": value}\n\n\nclass Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"probabilities\", \"labels\"])\n    return probs_and_labels(outputs, self._task_idx)\n\n\nclass Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"probabilities\", 
\"labels\"])\n    return probs_and_labels(outputs, self._task_idx)\n\n\nclass TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):\n  def transform(self, outputs):\n    outputs = self.maybe_apply_stratification(outputs, [\"probabilities\", \"labels\"])\n    return probs_and_labels(outputs, self._task_idx)\n\n\nclass Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  \"\"\"\n  Based on:\n  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420\n  \"\"\"\n\n  def __init__(self, num_samples, **kwargs):\n    super().__init__(**kwargs)\n    self.num_samples = num_samples\n\n  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n    scores, labels = outputs[\"logits\"], outputs[\"labels\"]\n    pos_scores = scores[labels == 1]\n    neg_scores = scores[labels == 0]\n    result = {\n      \"value\": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))]\n      > neg_scores[torch.randint(len(neg_scores), (self.num_samples,))]\n    }\n    return result\n\n\nclass PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  \"\"\"\n  The ranks of all positives\n  Based on:\n  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73\n  \"\"\"\n\n  def __init__(self, **kwargs):\n    super().__init__(**kwargs)\n\n  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n    scores, labels = outputs[\"logits\"], outputs[\"labels\"]\n    _, sorted_indices = scores.sort(descending=True)\n    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1\n    result = {\"value\": pos_ranks}\n    return result\n\n\nclass ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  \"\"\"\n  The reciprocal of the ranks of all\n  Based on:\n  
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74\n  \"\"\"\n\n  def __init__(self, **kwargs):\n    super().__init__(**kwargs)\n\n  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n    scores, labels = outputs[\"logits\"], outputs[\"labels\"]\n    _, sorted_indices = scores.sort(descending=True)\n    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1\n    result = {\"value\": torch.div(torch.ones_like(pos_ranks), pos_ranks)}\n    return result\n\n\nclass HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):\n  \"\"\"\n  The fraction of positives that rank in the top K among their negatives\n  Note that this is basically precision@k\n  Based on:\n  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75\n  \"\"\"\n\n  def __init__(self, k: int, **kwargs):\n    super().__init__(**kwargs)\n    self.k = k\n\n  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n    scores, labels = outputs[\"logits\"], outputs[\"labels\"]\n    _, sorted_indices = scores.sort(descending=True)\n    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1\n    result = {\"value\": (pos_ranks <= self.k).float()}\n    return result\n"
  },
  {
    "path": "core/test_metrics.py",
    "content": "from dataclasses import dataclass\n\nfrom tml.core import metrics as core_metrics\nfrom tml.core.metric_mixin import MetricMixin, prepend_transform\n\nimport torch\nfrom torchmetrics import MaxMetric, MetricCollection, SumMetric\n\n\n@dataclass\nclass MockStratifierConfig:\n  name: str\n  index: int\n  value: int\n\n\nclass Count(MetricMixin, SumMetric):\n  def transform(self, outputs):\n    return {\"value\": 1}\n\n\nMax = prepend_transform(MaxMetric, lambda outputs: {\"value\": outputs[\"value\"]})\n\n\ndef test_count_metric():\n  num_examples = 123\n  examples = [\n    {\"stuff\": 0},\n  ] * num_examples\n\n  metric = Count()\n  for outputs in examples:\n    metric.update(outputs)\n\n  assert metric.compute().item() == num_examples\n\n\ndef test_collections():\n  max_metric = Max()\n  count_metric = Count()\n  metric = MetricCollection([max_metric, count_metric])\n\n  examples = [{\"value\": idx} for idx in range(123)]\n  for outputs in examples:\n    metric.update(outputs)\n\n  assert metric.compute() == {\n    max_metric.__class__.__name__: len(examples) - 1,\n    count_metric.__class__.__name__: len(examples),\n  }\n\n\ndef test_task_dependent_ctr():\n  num_examples = 144\n  batch_size = 1024\n  outputs = [\n    {\n      \"stuff\": 0,\n      \"labels\": torch.arange(0, 6).repeat(batch_size, 1),\n    }\n    for idx in range(num_examples)\n  ]\n\n  for task_idx in range(5):\n    metric = core_metrics.Ctr(task_idx=task_idx)\n    for output in outputs:\n      metric.update(output)\n    assert metric.compute().item() == task_idx\n\n\ndef test_stratified_ctr():\n  outputs = [\n    {\n      \"stuff\": 0,\n      # [bsz, tasks]\n      \"labels\": torch.tensor(\n        [\n          [0, 1, 2, 3],\n          [1, 2, 3, 4],\n          [2, 3, 4, 0],\n        ]\n      ),\n      \"stratifiers\": {\n        # [bsz]\n        \"level\": torch.tensor(\n          [9, 0, 9],\n        ),\n      },\n    }\n  ]\n\n  stratifier = MockStratifierConfig(name=\"level\", 
index=2, value=9)\n  for task_idx in range(5):\n    metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)\n    for output in outputs:\n      metric.update(output)\n    # From the dataset of:\n    # [\n    #   [0, 1, 2, 3],\n    #   [1, 2, 3, 4],\n    #   [2, 3, 4, 0],\n    # ]\n    # we pick out\n    # [\n    #   [0, 1, 2, 3],\n    #   [2, 3, 4, 0],\n    # ]\n    # and with Ctr task_idx, we pick out\n    # [\n    #   [1,],\n    #   [3,],\n    # ]\n    assert metric.compute().item() == (1 + 3) / 2\n\n\ndef test_auc():\n  num_samples = 10000\n  metric = core_metrics.Auc(num_samples)\n  target = torch.tensor([0, 0, 1, 1, 1])\n  preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])\n  outputs_correct = {\"logits\": preds_correct, \"labels\": target}\n  preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])\n  outputs_bad = {\"logits\": preds_bad, \"labels\": target}\n\n  metric.update(outputs_correct)\n  assert metric.compute().item() == 1.0\n\n  metric.reset()\n  metric.update(outputs_bad)\n  assert metric.compute().item() == 0.0\n\n\ndef test_pos_rank():\n  metric = core_metrics.PosRanks()\n  target = torch.tensor([0, 0, 1, 1, 1])\n  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])\n  outputs_correct = {\"logits\": preds_correct, \"labels\": target}\n  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])\n  outputs_bad = {\"logits\": preds_bad, \"labels\": target}\n\n  metric.update(outputs_correct)\n  assert metric.compute().item() == 2.0\n\n  metric.reset()\n  metric.update(outputs_bad)\n  assert metric.compute().item() == 4.0\n\n\ndef test_reciprocal_rank():\n  metric = core_metrics.ReciprocalRank()\n  target = torch.tensor([0, 0, 1, 1, 1])\n  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])\n  outputs_correct = {\"logits\": preds_correct, \"labels\": target}\n  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])\n  outputs_bad = {\"logits\": preds_bad, \"labels\": target}\n\n  metric.update(outputs_correct)\n  assert 
abs(metric.compute().item() - 0.6111) < 0.001\n\n  metric.reset()\n  metric.update(outputs_bad)\n  assert abs(metric.compute().item() == 0.2611) < 0.001\n\n\ndef test_hit_k():\n  hit1_metric = core_metrics.HitAtK(1)\n  target = torch.tensor([0, 0, 1, 1, 1])\n  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])\n  outputs_correct = {\"logits\": preds_correct, \"labels\": target}\n  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])\n  outputs_bad = {\"logits\": preds_bad, \"labels\": target}\n\n  hit1_metric.update(outputs_correct)\n  assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001\n\n  hit1_metric.reset()\n  hit1_metric.update(outputs_bad)\n\n  assert hit1_metric.compute().item() == 0\n\n  hit3_metric = core_metrics.HitAtK(3)\n  hit3_metric.update(outputs_correct)\n  assert (hit3_metric.compute().item() - 0.66666) < 0.0001\n\n  hit3_metric.reset()\n  hit3_metric.update(outputs_bad)\n  assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001\n"
  },
  {
    "path": "core/test_train_pipeline.py",
    "content": "from dataclasses import dataclass\nfrom typing import Tuple\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.common.testing_utils import mock_pg\nfrom tml.core import train_pipeline\n\nimport torch\nfrom torchrec.distributed import DistributedModelParallel\n\n\n@dataclass\nclass MockDataclassBatch(DataclassBatch):\n  continuous_features: torch.Tensor\n  labels: torch.Tensor\n\n\nclass MockModule(torch.nn.Module):\n  def __init__(self) -> None:\n    super().__init__()\n    self.model = torch.nn.Linear(10, 1)\n    self.loss_fn = torch.nn.BCEWithLogitsLoss()\n\n  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:\n    pred = self.model(batch.continuous_features)\n    loss = self.loss_fn(pred, batch.labels)\n    return (loss, pred)\n\n\ndef create_batch(bsz: int):\n  return MockDataclassBatch(\n    continuous_features=torch.rand(bsz, 10).float(),\n    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),\n  )\n\n\ndef test_sparse_pipeline():\n  device = torch.device(\"cpu\")\n  model = MockModule().to(device)\n\n  steps = 8\n  example = create_batch(1)\n  dataloader = iter(example for _ in range(steps + 2))\n\n  results = []\n  with mock_pg():\n    d_model = DistributedModelParallel(model)\n    pipeline = train_pipeline.TrainPipelineSparseDist(\n      model=d_model,\n      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),\n      device=device,\n      grad_accum=2,\n    )\n    for _ in range(steps):\n      results.append(pipeline.progress(dataloader))\n\n  results = [elem.detach().numpy() for elem in results]\n  # Check gradients are accumulated, i.e. results do not change for every 0th and 1th.\n  for first, second in zip(results[::2], results[1::2]):\n    assert first == second, results\n\n  # Check we do update gradients, i.e. 
results do change for every 1th and 2nd.\n  for first, second in zip(results[1::2], results[2::2]):\n    assert first != second, results\n\n\ndef test_amp():\n  device = torch.device(\"cpu\")\n  model = MockModule().to(device)\n\n  steps = 8\n  example = create_batch(1)\n  dataloader = iter(example for _ in range(steps + 2))\n\n  results = []\n  with mock_pg():\n    d_model = DistributedModelParallel(model)\n    pipeline = train_pipeline.TrainPipelineSparseDist(\n      model=d_model,\n      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),\n      device=device,\n      enable_amp=True,\n      # Not supported on CPU.\n      enable_grad_scaling=False,\n    )\n    for _ in range(steps):\n      results.append(pipeline.progress(dataloader))\n\n  results = [elem.detach() for elem in results]\n  for value in results:\n    assert value.dtype == torch.bfloat16\n"
  },
  {
    "path": "core/train_pipeline.py",
    "content": "\"\"\"\nTaken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py\nwith TrainPipelineSparseDist.progress modified to support gradient accumulation.\n\n\"\"\"\nimport abc\nfrom dataclasses import dataclass, field\nimport logging\nfrom typing import (\n  Any,\n  cast,\n  Dict,\n  Generic,\n  Iterator,\n  List,\n  Optional,\n  Set,\n  Tuple,\n  TypeVar,\n)\n\nimport torch\nfrom torch.autograd.profiler import record_function\nfrom torch.fx.node import Node\nfrom torchrec.distributed.model_parallel import (\n  DistributedModelParallel,\n  ShardedModule,\n)\nfrom torchrec.distributed.types import Awaitable\nfrom torchrec.modules.feature_processor import BaseGroupedFeatureProcessor\nfrom torchrec.streamable import Multistreamable, Pipelineable\n\n\nlogger: logging.Logger = logging.getLogger(__name__)\n\n\nIn = TypeVar(\"In\", bound=Pipelineable)\nOut = TypeVar(\"Out\")\n\n\nclass TrainPipeline(abc.ABC, Generic[In, Out]):\n  @abc.abstractmethod\n  def progress(self, dataloader_iter: Iterator[In]) -> Out:\n    pass\n\n\ndef _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:\n  assert isinstance(\n    batch, (torch.Tensor, Pipelineable)\n  ), f\"{type(batch)} must implement Pipelineable interface\"\n  return cast(In, batch.to(device=device, non_blocking=non_blocking))\n\n\ndef _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:\n  if stream is None:\n    return\n  torch.cuda.current_stream().wait_stream(stream)\n  # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,\n  # PyTorch uses the \"caching allocator\" for memory allocation for tensors. When a tensor is\n  # freed, its memory is likely to be reused by newly constructed tenosrs.  By default,\n  # this allocator traces whether a tensor is still in use by only the CUDA stream where it\n  # was created.   
When a tensor is used by additional CUDA streams, we need to call record_stream\n  # to tell the allocator about all these streams.  Otherwise, the allocator might free the\n  # underlying memory of the tensor once it is no longer used by the creator stream.  This is\n  # a notable programming trick when we write programs using multi CUDA streams.\n  cur_stream = torch.cuda.current_stream()\n  assert isinstance(\n    batch, (torch.Tensor, Multistreamable)\n  ), f\"{type(batch)} must implement Multistreamable interface\"\n  batch.record_stream(cur_stream)\n\n\nclass TrainPipelineBase(TrainPipeline[In, Out]):\n  \"\"\"\n  This class runs training iterations using a pipeline of two stages, each as a CUDA\n  stream, namely, the current (default) stream and `self._memcpy_stream`. For each\n  iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU\n  memory, and the default stream runs forward, backward, and optimization.\n  \"\"\"\n\n  def __init__(\n    self,\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    device: torch.device,\n  ) -> None:\n    self._model = model\n    self._optimizer = optimizer\n    self._device = device\n    self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (\n      torch.cuda.Stream() if device.type == \"cuda\" else None\n    )\n    self._cur_batch: Optional[In] = None\n    self._connected = False\n\n  def _connect(self, dataloader_iter: Iterator[In]) -> None:\n    cur_batch = next(dataloader_iter)\n    self._cur_batch = cur_batch\n    with torch.cuda.stream(self._memcpy_stream):\n      self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)\n    self._connected = True\n\n  def progress(self, dataloader_iter: Iterator[In]) -> Out:\n    if not self._connected:\n      self._connect(dataloader_iter)\n\n    # Fetch next batch\n    with record_function(\"## next_batch ##\"):\n      next_batch = next(dataloader_iter)\n    cur_batch = self._cur_batch\n    assert cur_batch is not 
None\n\n    if self._model.training:\n      with record_function(\"## zero_grad ##\"):\n        self._optimizer.zero_grad()\n\n    with record_function(\"## wait_for_batch ##\"):\n      _wait_for_batch(cur_batch, self._memcpy_stream)\n\n    with record_function(\"## forward ##\"):\n      (losses, output) = self._model(cur_batch)\n\n    if self._model.training:\n      with record_function(\"## backward ##\"):\n        torch.sum(losses, dim=0).backward()\n\n    # Copy the next batch to GPU\n    self._cur_batch = cur_batch = next_batch\n    with record_function(\"## copy_batch_to_gpu ##\"):\n      with torch.cuda.stream(self._memcpy_stream):\n        self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)\n\n    # Update\n    if self._model.training:\n      with record_function(\"## optimizer ##\"):\n        self._optimizer.step()\n\n    return output\n\n\nclass Tracer(torch.fx.Tracer):\n  # Disable proxying buffers during tracing. Ideally, proxying buffers would\n  # be disabled, but some models are currently mutating buffer values, which\n  # causes errors during tracing. 
If those models can be rewritten to not do\n  # that, we can likely remove this line\n  proxy_buffer_attributes = False\n\n  def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:\n    super().__init__()\n    self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []\n\n  def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:\n    if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:\n      return True\n    return super().is_leaf_module(m, module_qualified_name)\n\n\n@dataclass\nclass TrainPipelineContext:\n  # pyre-ignore [4]\n  input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)\n  module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)\n  # pyre-ignore [4]\n  feature_processor_forwards: List[Any] = field(default_factory=list)\n\n\n@dataclass\nclass ArgInfo:\n  # attributes of input batch, e.g. batch.attr1.attr2 call\n  # will produce [\"attr1\", \"attr2\"]\n  input_attrs: List[str]\n  # batch[attr1].attr2 will produce [True, False]\n  is_getitems: List[bool]\n  # name for kwarg of pipelined forward() call or None\n  # for a positional arg\n  name: Optional[str]\n\n\nclass PipelinedForward:\n  def __init__(\n    self,\n    name: str,\n    args: List[ArgInfo],\n    module: ShardedModule,\n    context: TrainPipelineContext,\n    dist_stream: Optional[torch.cuda.streams.Stream],\n  ) -> None:\n    self._name = name\n    self._args = args\n    self._module = module\n    self._context = context\n    self._dist_stream = dist_stream\n\n  # pyre-ignore [2, 24]\n  def __call__(self, *input, **kwargs) -> Awaitable:\n    assert self._name in self._context.input_dist_requests\n    request = self._context.input_dist_requests[self._name]\n    assert isinstance(request, Awaitable)\n    with record_function(\"## wait_sparse_data_dist ##\"):\n      # Finish waiting on the dist_stream,\n      # in case some delayed stream scheduling happens during 
the wait() call.\n      with torch.cuda.stream(self._dist_stream):\n        data = request.wait()\n\n    # Make sure that both result of input_dist and context\n    # are properly transferred to the current stream.\n    if self._dist_stream is not None:\n      torch.cuda.current_stream().wait_stream(self._dist_stream)\n      cur_stream = torch.cuda.current_stream()\n\n      assert isinstance(\n        data, (torch.Tensor, Multistreamable)\n      ), f\"{type(data)} must implement Multistreamable interface\"\n      # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.\n      data.record_stream(cur_stream)\n\n      ctx = self._context.module_contexts[self._name]\n      ctx.record_stream(cur_stream)\n\n    if len(self._context.feature_processor_forwards) > 0:\n      with record_function(\"## feature_processor ##\"):\n        for sparse_feature in data:\n          if sparse_feature.id_score_list_features is not None:\n            for fp_forward in self._context.feature_processor_forwards:\n              sparse_feature.id_score_list_features = fp_forward(\n                sparse_feature.id_score_list_features\n              )\n\n    return self._module.compute_and_output_dist(self._context.module_contexts[self._name], data)\n\n  @property\n  def name(self) -> str:\n    return self._name\n\n  @property\n  def args(self) -> List[ArgInfo]:\n    return self._args\n\n\ndef _start_data_dist(\n  pipelined_modules: List[ShardedModule],\n  batch: In,\n  context: TrainPipelineContext,\n) -> None:\n  context.input_dist_requests.clear()\n  context.module_contexts.clear()\n  for module in pipelined_modules:\n    forward = module.forward\n    assert isinstance(forward, PipelinedForward)\n\n    # Retrieve argument for the input_dist of EBC\n    # is_getitem True means this argument could be retrieved by a list\n    # False means this argument is getting while getattr\n    # and this info was done in the _rewrite_model by tracing the\n    # entire model to get the 
arg_info_list\n    args = []\n    kwargs = {}\n    for arg_info in forward.args:\n      if arg_info.input_attrs:\n        arg = batch\n        for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):\n          if is_getitem:\n            arg = arg[attr]\n          else:\n            arg = getattr(arg, attr)\n        if arg_info.name:\n          kwargs[arg_info.name] = arg\n        else:\n          args.append(arg)\n      else:\n        args.append(None)\n    # Start input distribution.\n    module_ctx = module.create_context()\n    context.module_contexts[forward.name] = module_ctx\n    context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)\n\n  # Call wait on the first awaitable in the input dist for the tensor splits\n  for key, awaitable in context.input_dist_requests.items():\n    context.input_dist_requests[key] = awaitable.wait()\n\n\ndef _get_node_args_helper(\n  # pyre-ignore\n  arguments,\n  num_found: int,\n  feature_processor_arguments: Optional[List[Node]] = None,\n) -> Tuple[List[ArgInfo], int]:\n  \"\"\"\n  Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.\n  It also counts the number of (args + kwargs) found.\n  \"\"\"\n\n  arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]\n  for arg, arg_info in zip(arguments, arg_info_list):\n    if arg is None:\n      num_found += 1\n      continue\n    while True:\n      if not isinstance(arg, torch.fx.Node):\n        break\n      child_node = arg\n\n      if child_node.op == \"placeholder\":\n        num_found += 1\n        break\n      # skip this fp node\n      elif feature_processor_arguments is not None and child_node in feature_processor_arguments:\n        arg = child_node.args[0]\n      elif (\n        child_node.op == \"call_function\"\n        and child_node.target.__module__ == \"builtins\"\n        # pyre-ignore[16]\n        and child_node.target.__name__ == \"getattr\"\n      ):\n        
arg_info.input_attrs.insert(0, child_node.args[1])\n        arg_info.is_getitems.insert(0, False)\n        arg = child_node.args[0]\n      elif (\n        child_node.op == \"call_function\"\n        and child_node.target.__module__ == \"_operator\"\n        # pyre-ignore[16]\n        and child_node.target.__name__ == \"getitem\"\n      ):\n        arg_info.input_attrs.insert(0, child_node.args[1])\n        arg_info.is_getitems.insert(0, True)\n        arg = child_node.args[0]\n      else:\n        break\n  return arg_info_list, num_found\n\n\ndef _get_node_args(\n  node: Node, feature_processor_nodes: Optional[List[Node]] = None\n) -> Tuple[List[ArgInfo], int]:\n  num_found = 0\n  pos_arg_info_list, num_found = _get_node_args_helper(\n    node.args, num_found, feature_processor_nodes\n  )\n  kwargs_arg_info_list, num_found = _get_node_args_helper(node.kwargs.values(), num_found)\n\n  # Replace with proper names for kwargs\n  for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list):\n    arg_info_list.name = name\n\n  arg_info_list = pos_arg_info_list + kwargs_arg_info_list\n  return arg_info_list, num_found\n\n\ndef _get_unsharded_module_names_helper(\n  model: torch.nn.Module,\n  path: str,\n  unsharded_module_names: Set[str],\n) -> bool:\n  sharded_children = set()\n  for name, child in model.named_children():\n    curr_path = path + name\n    if isinstance(child, ShardedModule):\n      sharded_children.add(name)\n    else:\n      child_sharded = _get_unsharded_module_names_helper(\n        child,\n        curr_path + \".\",\n        unsharded_module_names,\n      )\n      if child_sharded:\n        sharded_children.add(name)\n\n  if len(sharded_children) > 0:\n    for name, _ in model.named_children():\n      if name not in sharded_children:\n        unsharded_module_names.add(path + name)\n\n  return len(sharded_children) > 0\n\n\ndef _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:\n  \"\"\"\n  Returns a list of top level modules do 
not contain any sharded sub modules.\n  \"\"\"\n\n  unsharded_module_names: Set[str] = set()\n  _get_unsharded_module_names_helper(\n    model,\n    \"\",\n    unsharded_module_names,\n  )\n  return list(unsharded_module_names)\n\n\ndef _rewrite_model(  # noqa C901\n  model: torch.nn.Module,\n  context: TrainPipelineContext,\n  dist_stream: Optional[torch.cuda.streams.Stream],\n) -> List[ShardedModule]:\n\n  # Get underlying nn.Module\n  if isinstance(model, DistributedModelParallel):\n    model = model.module\n\n  # Collect a list of sharded modules.\n  sharded_modules = {}\n  fp_modules = {}\n  for name, m in model.named_modules():\n    if isinstance(m, ShardedModule):\n      sharded_modules[name] = m\n    if isinstance(m, BaseGroupedFeatureProcessor):\n      fp_modules[name] = m\n\n  # Trace a model.\n  tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))\n  graph = tracer.trace(model)\n\n  feature_processor_nodes = []\n  # find the fp node\n  for node in graph.nodes:\n    if node.op == \"call_module\" and node.target in fp_modules:\n      feature_processor_nodes.append(node)\n  # Select sharded modules, which are top-level in the forward call graph,\n  # i.e. 
which don't have input transformations, i.e.\n  # rely only on 'builtins.getattr'.\n  ret = []\n  for node in graph.nodes:\n    if node.op == \"call_module\" and node.target in sharded_modules:\n      total_num_args = len(node.args) + len(node.kwargs)\n      if total_num_args == 0:\n        continue\n      arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)\n      if num_found == total_num_args:\n        logger.info(f\"Module '{node.target}'' will be pipelined\")\n        child = sharded_modules[node.target]\n        child.forward = PipelinedForward(\n          node.target,\n          arg_info_list,\n          child,\n          context,\n          dist_stream,\n        )\n        ret.append(child)\n  return ret\n\n\nclass TrainPipelineSparseDist(TrainPipeline[In, Out]):\n  \"\"\"\n  This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with\n  forward and backward. This helps hide the all2all latency while preserving the\n  training forward / backward ordering.\n\n  stage 3: forward, backward - uses default CUDA stream\n  stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream\n  stage 1: device transfer - uses memcpy CUDA stream\n\n  `ShardedModule.input_dist()` is only done for top-level modules in the call graph.\n  To be considered a top-level module, a module can only depend on 'getattr' calls on\n  input.\n\n  Input model must be symbolically traceable with the exception of `ShardedModule` and\n  `DistributedDataParallel` modules.\n  \"\"\"\n\n  synced_pipeline_id: Dict[int, int] = {}\n\n  def __init__(\n    self,\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    device: torch.device,\n    enable_amp: bool = False,\n    enable_grad_scaling: bool = True,\n    grad_accum: Optional[int] = None,\n  ) -> None:\n    self._model = model\n    self._optimizer = optimizer\n    self._device = device\n    self._enable_amp = enable_amp\n    # NOTE: Pending upstream feedback, but two flags because we 
can run AMP without CUDA but cannot scale gradients without CUDA.\n    # Background on gradient/loss scaling\n    # https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling\n    # https://pytorch.org/docs/stable/amp.html#gradient-scaling\n    self._enable_grad_scaling = enable_grad_scaling\n    self._grad_scaler = torch.cuda.amp.GradScaler(\n      enabled=self._enable_amp and self._enable_grad_scaling\n    )\n    logging.info(f\"Amp is enabled: {self._enable_amp}\")\n\n    # use two data streams to support two concurrent batches\n    if device.type == \"cuda\":\n      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()\n      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()\n    else:\n      if self._enable_amp:\n        logging.warning(\"Amp is enabled, but no CUDA available\")\n      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None\n      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None\n    self._batch_i: Optional[In] = None\n    self._batch_ip1: Optional[In] = None\n    self._batch_ip2: Optional[In] = None\n    self._connected = False\n    self._context = TrainPipelineContext()\n    self._pipelined_modules: List[ShardedModule] = []\n\n    self._progress_calls = 0\n    if grad_accum is not None:\n      assert isinstance(grad_accum, int) and grad_accum > 0\n    self._grad_accum = grad_accum\n\n  def _connect(self, dataloader_iter: Iterator[In]) -> None:\n    # batch 1\n    with torch.cuda.stream(self._memcpy_stream):\n      batch_i = next(dataloader_iter)\n      self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)\n      # Try to pipeline input data dist.\n      self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)\n\n    with torch.cuda.stream(self._data_dist_stream):\n      _wait_for_batch(batch_i, self._memcpy_stream)\n      _start_data_dist(self._pipelined_modules, 
batch_i, self._context)\n\n    # batch 2\n    with torch.cuda.stream(self._memcpy_stream):\n      batch_ip1 = next(dataloader_iter)\n      self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)\n    self._connected = True\n    self.__class__.synced_pipeline_id[id(self._model)] = id(self)\n\n  def progress(self, dataloader_iter: Iterator[In]) -> Out:\n    \"\"\"\n    NOTE: This method has been updated to perform gradient accumulation.\n    If `_grad_accum` is set, then loss values are scaled by this amount and\n    optimizer update/reset is skipped for `_grad_accum` calls of `progress`\n    (congruent to training steps), and then update/reset on every `_grad_accum`th\n    step.\n\n    \"\"\"\n    should_step_optimizer = (\n      self._grad_accum is not None\n      and self._progress_calls > 0\n      and (self._progress_calls + 1) % self._grad_accum == 0\n    ) or self._grad_accum is None\n    should_reset_optimizer = (\n      self._grad_accum is not None\n      and self._progress_calls > 0\n      and (self._progress_calls + 2) % self._grad_accum == 0\n    ) or self._grad_accum is None\n\n    if not self._connected:\n      self._connect(dataloader_iter)\n    elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):\n      self._sync_pipeline()\n      self.__class__.synced_pipeline_id[id(self._model)] = id(self)\n\n    if self._model.training and should_reset_optimizer:\n      with record_function(\"## zero_grad ##\"):\n        self._optimizer.zero_grad()\n\n    with record_function(\"## copy_batch_to_gpu ##\"):\n      with torch.cuda.stream(self._memcpy_stream):\n        batch_ip2 = next(dataloader_iter)\n        self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)\n    batch_i = cast(In, self._batch_i)\n    batch_ip1 = cast(In, self._batch_ip1)\n\n    with record_function(\"## wait_for_batch ##\"):\n      _wait_for_batch(batch_i, self._data_dist_stream)\n\n    # Forward\n    with 
record_function(\"## forward ##\"):\n      # if using multiple streams (ie. CUDA), create an event in default stream\n      # before starting forward pass\n      if self._data_dist_stream:\n        event = torch.cuda.current_stream().record_event()\n      if self._enable_amp:\n        # conditionally apply the model to the batch in the autocast context\n        # it appears that `enabled=self._enable_amp` should handle this,\n        # but it does not.\n        with torch.autocast(\n          device_type=self._device.type,\n          dtype=torch.bfloat16,\n          enabled=self._enable_amp,\n        ):\n          (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))\n      else:\n        (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))\n\n    # Data Distribution\n    with record_function(\"## sparse_data_dist ##\"):\n      with torch.cuda.stream(self._data_dist_stream):\n        _wait_for_batch(batch_ip1, self._memcpy_stream)\n        # Ensure event in default stream has been called before\n        # starting data dist\n        if self._data_dist_stream:\n          # pyre-ignore [61]: Local variable `event` is undefined, or not always defined\n          self._data_dist_stream.wait_event(event)\n        _start_data_dist(self._pipelined_modules, batch_ip1, self._context)\n\n    if self._model.training:\n      # Backward\n      with record_function(\"## backward ##\"):\n        # Loss is normalize by number of accumulation steps.\n        # The reported loss in `output['loss']` remains the unnormalized value.\n        if self._grad_accum is not None:\n          losses = losses / self._grad_accum\n        self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()\n\n      if should_step_optimizer:\n        # Update\n        with record_function(\"## optimizer ##\"):\n          self._grad_scaler.step(self._optimizer)\n          self._grad_scaler.update()\n\n    self._batch_i = batch_ip1\n    self._batch_ip1 = batch_ip2\n\n    
if self._model.training:\n      self._progress_calls += 1\n\n    return output\n\n  def _sync_pipeline(self) -> None:\n    \"\"\"\n    Syncs `PipelinedForward` for sharded modules with context and dist stream of the\n    current train pipeline. Used when switching between train pipelines for the same\n    model.\n    \"\"\"\n    for module in self._pipelined_modules:\n      module.forward._context = self._context\n      module.forward._dist_stream = self._data_dist_stream\n"
  },
  {
    "path": "images/init_venv.sh",
    "content": "#! /bin/sh\n\n# Creates a fresh Python 3.10 virtualenv at $HOME/tml_venv and installs the pinned requirements.\n# Run from the repository root: the requirements path and the `tml` symlink are relative to it.\n\nset -e\n\n# POSIX test syntax: `[[ ... == ... ]]` is a bashism and errors out under a strict /bin/sh (e.g. dash).\nif [ \"$(uname)\" = \"Darwin\" ]; then\n  echo \"Only supported on Linux.\"\n  exit 1\nfi\n\n# You may need to point this to a version of python 3.10 (override via the PYTHONBIN env var).\nPYTHONBIN=\"${PYTHONBIN:-/opt/ee/python/3.10/bin/python3.10}\"\necho Using \"PYTHONBIN=$PYTHONBIN\"\n\n# The venv is rebuilt from scratch on every run; these things are not made to last.\nVENV_PATH=\"$HOME/tml_venv\"\nrm -rf \"$VENV_PATH\"\n\"$PYTHONBIN\" -m venv \"$VENV_PATH\"\n\n# shellcheck source=/dev/null\n. \"$VENV_PATH/bin/activate\"\n\npip --require-virtualenv install -U pip\npip --require-virtualenv install --no-deps -r images/requirements.txt\n\n# Expose the repository checkout as the importable `tml` package inside the venv.\nln -s \"$(pwd)\" \"$VENV_PATH/lib/python3.10/site-packages/tml\"\n\necho \"Now run source ${VENV_PATH}/bin/activate to get going.\"\n"
  },
  {
    "path": "images/requirements.txt",
    "content": "absl-py==1.4.0\naiofiles==22.1.0\naiohttp==3.8.3\naiosignal==1.3.1\nappdirs==1.4.4\narrow==1.2.3\nasttokens==2.2.1\nastunparse==1.6.3\nasync-timeout==4.0.2\nattrs==22.1.0\nbackcall==0.2.0\nblack==22.6.0\ncachetools==5.3.0\ncblack==22.6.0\ncertifi==2022.12.7\ncfgv==3.3.1\ncharset-normalizer==2.1.1\nclick==8.1.3\ncmake==3.25.0\nCython==0.29.32\ndecorator==5.1.1\ndistlib==0.3.6\ndistro==1.8.0\ndm-tree==0.1.6\ndocker==6.0.1\ndocker-pycreds==0.4.0\ndocstring-parser==0.8.1\nexceptiongroup==1.1.0\nexecuting==1.2.0\nfbgemm-gpu-cpu==0.3.2\nfilelock==3.8.2\nfire==0.5.0\nflatbuffers==1.12\nfrozenlist==1.3.3\nfsspec==2022.11.0\ngast==0.4.0\ngcsfs==2022.11.0\ngitdb==4.0.10\nGitPython==3.1.31\ngoogle-api-core==2.8.2\ngoogle-auth==2.16.0\ngoogle-auth-oauthlib==0.4.6\ngoogle-cloud-core==2.3.2\ngoogle-cloud-storage==2.7.0\ngoogle-crc32c==1.5.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.4.1\ngoogleapis-common-protos==1.56.4\ngrpcio==1.51.1\nh5py==3.8.0\nhypothesis==6.61.0\nidentify==2.5.17\nidna==3.4\nimportlib-metadata==6.0.0\niniconfig==2.0.0\niopath==0.1.10\nipdb==0.13.11\nipython==8.10.0\njedi==0.18.2\nJinja2==3.1.2\nkeras==2.9.0\nKeras-Preprocessing==1.1.2\nlibclang==15.0.6.1\nlibcst==0.4.9\nMarkdown==3.4.1\nMarkupSafe==2.1.1\nmatplotlib-inline==0.1.6\nmoreorless==0.4.0\nmultidict==6.0.4\nmypy==1.0.1\nmypy-extensions==0.4.3\nnest-asyncio==1.5.6\nninja==1.11.1\nnodeenv==1.7.0\nnumpy==1.22.0\nnvidia-cublas-cu11==11.10.3.66\nnvidia-cuda-nvrtc-cu11==11.7.99\nnvidia-cuda-runtime-cu11==11.7.99\nnvidia-cudnn-cu11==8.5.0.96\noauthlib==3.2.2\nopt-einsum==3.3.0\npackaging==22.0\npandas==1.5.3\nparso==0.8.3\npathspec==0.11.0\npathtools==0.1.2\npexpect==4.8.0\npickleshare==0.7.5\nplatformdirs==3.0.0\npluggy==1.0.0\nportalocker==2.6.0\nportpicker==1.5.2\npre-commit==3.0.4\nprompt-toolkit==3.0.36\nprotobuf==3.20.2\npsutil==5.9.4\nptyprocess==0.7.0\npure-eval==0.2.2\npyarrow==10.0.1\npyasn1==0.4.8\npyasn1-modules==0.2.8\npydantic==1.9.0\npyDeprecate==0.3.2\nPygments==2.14.
0\npyparsing==3.0.9\npyre-extensions==0.0.27\npytest==7.2.1\npytest-mypy==0.10.3\npython-dateutil==2.8.2\npytz==2022.6\nPyYAML==6.0.0\nrequests==2.28.1\nrequests-oauthlib==1.3.1\nrsa==4.9\nscikit-build==0.16.3\nsentry-sdk==1.16.0\nsetproctitle==1.3.2\nsix==1.16.0\nsmmap==5.0.0\nsortedcontainers==2.4.0\nstack-data==0.6.2\nstdlibs==2022.10.9\ntabulate==0.9.0\ntensorboard==2.9.0\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.1\ntensorflow==2.9.3\ntensorflow-estimator==2.9.0\ntensorflow-io-gcs-filesystem==0.30.0\ntermcolor==2.2.0\ntoml==0.10.2\ntomli==2.0.1\ntorch==1.13.1\ntorchmetrics==0.11.0\ntorchrec==0.3.2\ntorchsnapshot==0.1.0\ntorchx==0.3.0\ntqdm==4.64.1\ntrailrunner==1.2.1\ntraitlets==5.9.0\ntyping-inspect==0.8.0\ntyping_extensions==4.4.0\nurllib3==1.26.13\nusort==1.0.5\nvirtualenv==20.19.0\nwandb==0.13.11\nwcwidth==0.2.6\nwebsocket-client==1.4.2\nWerkzeug==2.2.3\nwrapt==1.14.1\nyarl==1.8.2\nzipp==3.12.1\n"
  },
  {
    "path": "machines/environment.py",
    "content": "\"\"\"Helpers that describe the current job's environment from environment variables.\n\nWhen `SPEC_TYPE` is set the job is considered to run on KF (see `on_kf`); otherwise SLURM-style\nvariables (and `TASK_TYPE` / `HAS_READERS`) are read instead.\n\"\"\"\nimport json\nimport os\nfrom typing import List\n\n\nKF_DDS_PORT: int = 5050\nSLURM_DDS_PORT: int = 5051\nFLIGHT_SERVER_PORT: int = 2222\n\n\ndef on_kf():\n  \"\"\"Returns True when running on KF, detected by the presence of `SPEC_TYPE`.\"\"\"\n  return \"SPEC_TYPE\" in os.environ\n\n\ndef has_readers():\n  \"\"\"Returns whether the job has dataset worker (reader) tasks.\"\"\"\n  if on_kf():\n    machines_config_env = json.loads(os.environ[\"MACHINES_CONFIG\"])\n    return machines_config_env[\"dataset_worker\"] is not None\n  return os.environ.get(\"HAS_READERS\", \"False\") == \"True\"\n\n\ndef get_task_type():\n  \"\"\"Returns this task's type (`SPEC_TYPE` on KF, `TASK_TYPE` otherwise).\"\"\"\n  if on_kf():\n    return os.environ[\"SPEC_TYPE\"]\n  return os.environ[\"TASK_TYPE\"]\n\n\ndef is_chief() -> bool:\n  return get_task_type() == \"chief\"\n\n\ndef is_reader() -> bool:\n  return get_task_type() == \"datasetworker\"\n\n\ndef is_dispatcher() -> bool:\n  return get_task_type() == \"datasetdispatcher\"\n\n\ndef get_task_index():\n  \"\"\"Returns this task's index, parsed from the trailing `-<n>` of `MY_POD_NAME` on KF.\n\n  Raises:\n    NotImplementedError: when not running on KF.\n  \"\"\"\n  if on_kf():\n    pod_name = os.environ[\"MY_POD_NAME\"]\n    return int(pod_name.split(\"-\")[-1])\n  else:\n    raise NotImplementedError(\"get_task_index is only implemented on KF (SPEC_TYPE set).\")\n\n\ndef get_reader_port():\n  \"\"\"Returns the DDS port: KF_DDS_PORT on KF, SLURM_DDS_PORT otherwise.\"\"\"\n  if on_kf():\n    return KF_DDS_PORT\n  return SLURM_DDS_PORT\n\n\ndef get_dds():\n  \"\"\"Returns the `grpc://` address of the DDS dispatcher, or None if the job has no readers.\"\"\"\n  if not has_readers():\n    return None\n  dispatcher_address = get_dds_dispatcher_address()\n  if dispatcher_address:\n    return f\"grpc://{dispatcher_address}\"\n  else:\n    raise ValueError(\"Job does not have DDS.\")\n\n\ndef get_dds_dispatcher_address():\n  \"\"\"Returns `host:port` of the DDS dispatcher, or None if the job has no readers.\"\"\"\n  if not has_readers():\n    return None\n  if on_kf():\n    job_name = os.environ[\"JOB_NAME\"]\n    dds_host = f\"{job_name}-datasetdispatcher-0\"\n  else:\n    dds_host = os.environ[\"SLURM_JOB_NODELIST_HET_GROUP_0\"]\n  return f\"{dds_host}:{get_reader_port()}\"\n\n\ndef get_dds_worker_address():\n  \"\"\"Returns `host:port` of this task's DDS worker, or None if the job has no readers.\"\"\"\n  if not has_readers():\n    return None\n  if on_kf():\n    job_name = os.environ[\"JOB_NAME\"]\n    task_index = get_task_index()\n    return f\"{job_name}-datasetworker-{task_index}:{get_reader_port()}\"\n  else:\n    node = os.environ[\"SLURMD_NODENAME\"]\n    return f\"{node}:{get_reader_port()}\"\n\n\ndef get_num_readers():\n  \"\"\"Returns the number of dataset worker tasks (0 if the job has none).\"\"\"\n  if not has_readers():\n    return 0\n  if on_kf():\n    machines_config_env = json.loads(os.environ[\"MACHINES_CONFIG\"])\n    return int(machines_config_env[\"num_dataset_workers\"] or 0)\n  # NOTE(review): assumes an uncompressed, comma-separated node list (no `node[1-3]` ranges).\n  return len(os.environ[\"SLURM_JOB_NODELIST_HET_GROUP_1\"].split(\",\"))\n\n\ndef get_flight_server_addresses():\n  \"\"\"Returns the `grpc://` flight server address of every dataset worker.\n\n  Raises:\n    NotImplementedError: when not running on KF.\n  \"\"\"\n  if on_kf():\n    job_name = os.environ[\"JOB_NAME\"]\n    return [\n      f\"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}\"\n      for task_index in range(get_num_readers())\n    ]\n  else:\n    raise NotImplementedError(\n      \"get_flight_server_addresses is only implemented on KF (SPEC_TYPE set).\"\n    )\n\n\ndef get_dds_journaling_dir():\n  \"\"\"Returns `DATASET_JOURNALING_DIR`, or None when it is unset.\"\"\"\n  return os.environ.get(\"DATASET_JOURNALING_DIR\", None)\n"
  },
  {
    "path": "machines/get_env.py",
    "content": "import tml.machines.environment as env\n\nfrom absl import app, flags\n\n\nFLAGS = flags.FLAGS\nflags.DEFINE_string(\"property\", None, \"Which property of the current environment to fetch.\")\n\n\ndef main(argv):\n  \"\"\"Prints the value of `--property` for the current environment to stdout.\n\n  The output is meant to be captured by shell scripts.\n\n  Raises:\n    ValueError: if `--property` is not a recognized property name.\n  \"\"\"\n  if FLAGS.property == \"using_dds\":\n    print(f\"{env.has_readers()}\", flush=True)\n  elif FLAGS.property == \"has_readers\":\n    print(f\"{env.has_readers()}\", flush=True)\n  elif FLAGS.property == \"get_task_type\":\n    print(f\"{env.get_task_type()}\", flush=True)\n  elif FLAGS.property == \"is_datasetworker\":\n    print(f\"{env.is_reader()}\", flush=True)\n  elif FLAGS.property == \"is_dds_dispatcher\":\n    print(f\"{env.is_dispatcher()}\", flush=True)\n  elif FLAGS.property == \"get_task_index\":\n    print(f\"{env.get_task_index()}\", flush=True)\n  elif FLAGS.property == \"get_dataset_service\":\n    print(f\"{env.get_dds()}\", flush=True)\n  elif FLAGS.property == \"get_dds_dispatcher_address\":\n    print(f\"{env.get_dds_dispatcher_address()}\", flush=True)\n  elif FLAGS.property == \"get_dds_worker_address\":\n    print(f\"{env.get_dds_worker_address()}\", flush=True)\n  elif FLAGS.property == \"get_dds_port\":\n    print(f\"{env.get_reader_port()}\", flush=True)\n  elif FLAGS.property == \"get_dds_journaling_dir\":\n    print(f\"{env.get_dds_journaling_dir()}\", flush=True)\n  elif FLAGS.property == \"should_start_dds\":\n    print(env.is_reader() or env.is_dispatcher(), flush=True)\n  else:\n    # Fail loudly (as list_ops.py does for an unknown --op) instead of printing nothing,\n    # which a calling shell script would silently read as an empty value.\n    raise ValueError(f\"property {FLAGS.property} not recognized.\")\n\n\nif __name__ == \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "machines/is_venv.py",
    "content": "\"\"\"This is intended to be run as a module.\ne.g. python -m tml.machines.is_venv\n\nExits with 0 if running in a venv, otherwise 1.\n\"\"\"\n\nimport sys\nimport logging\n\n\ndef is_venv():\n  \"\"\"Returns True when the interpreter's prefix differs from its base prefix (i.e. inside a venv).\"\"\"\n  # See https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv\n  return sys.base_prefix != sys.prefix\n\n\ndef _main():\n  # Without this the root logger defaults to WARNING and the INFO message below is dropped.\n  logging.basicConfig(level=logging.INFO)\n  if is_venv():\n    logging.info(\"In venv %s\", sys.prefix)\n    sys.exit(0)\n  else:\n    logging.error(\"Not in venv\")\n    sys.exit(1)\n\n\nif __name__ == \"__main__\":\n  _main()\n"
  },
  {
    "path": "machines/list_ops.py",
    "content": "\"\"\"\nSimple str.split() parsing of input string\n\nusage example:\n  python list_ops.py --input_list=$INPUT [--sep=\",\"] [--op=<len|select>] [--elem=$INDEX]\n\nArgs:\n  - input_list (required): input string\n  - sep (default \",\"): separator string\n  - elem (default 0): integer index\n  - op (default \"select\"): either `len` or `select`\n    - len: prints len(input_list.split(sep))\n    - select: prints input_list.split(sep)[elem]\n\nTypical usage would be in a bash script, e.g.:\n\n  LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)\n\n\"\"\"\nfrom absl import app, flags\n\n\nFLAGS = flags.FLAGS\nflags.DEFINE_string(\"input_list\", None, \"string to parse as list\")\nflags.DEFINE_integer(\"elem\", 0, \"which element to take\")\nflags.DEFINE_string(\"sep\", \",\", \"separator\")\nflags.DEFINE_string(\"op\", \"select\", \"operation to do\")\n\n\ndef main(argv):\n  \"\"\"Splits `--input_list` on `--sep` and prints one element or the number of elements.\n\n  Raises:\n    ValueError: if `--op` is neither `select` nor `len`.\n    IndexError: if `--op=select` and `--elem` is out of range.\n  \"\"\"\n  split_list = FLAGS.input_list.split(FLAGS.sep)\n  if FLAGS.op == \"select\":\n    print(split_list[FLAGS.elem], flush=True)\n  elif FLAGS.op == \"len\":\n    print(len(split_list), flush=True)\n  else:\n    raise ValueError(f\"operation {FLAGS.op} not recognized.\")\n\n\nif __name__ == \"__main__\":\n  # Without this, a missing --input_list surfaces as an AttributeError on `None.split`.\n  flags.mark_flag_as_required(\"input_list\")\n  app.run(main)\n"
  },
  {
    "path": "metrics/__init__.py",
    "content": "from .aggregation import StableMean  # noqa\nfrom .auroc import AUROCWithMWU  # noqa\nfrom .rce import NRCE, RCE  # noqa\n"
  },
  {
    "path": "metrics/aggregation.py",
    "content": "\"\"\"\nContains aggregation metrics.\n\"\"\"\nfrom typing import Tuple, Union\n\nimport torch\nimport torchmetrics\n\n\ndef update_mean(\n  current_mean: torch.Tensor,\n  current_weight_sum: torch.Tensor,\n  value: torch.Tensor,\n  weight: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n  \"\"\"\n  Update the mean according to Welford formula:\n  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.\n  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.\n  Args:\n    current_mean: The value of the current accumulated mean.\n    current_weight_sum: The current sum of weights.\n    value: The new value that needs to be added to get a new mean.\n    weight: The weights for the new value. Must be broadcastable to the shape of `value`.\n\n  Returns: The updated mean and updated sum of weights. While the accumulated sum of weights is\n    zero (e.g. only zero-weight samples have been seen so far), the mean is left unchanged\n    instead of becoming NaN.\n\n  \"\"\"\n  weight = torch.broadcast_to(weight, value.shape)\n\n  # Avoiding (on purpose) in-place operation when using += in case\n  # current_mean and current_weight_sum share the same storage\n  current_weight_sum = current_weight_sum + torch.sum(weight)\n  # Dividing by a zero weight sum would give 0/0 = NaN, which would then poison every later\n  # update. Divide by 1 instead: assuming non-negative weights, they are all zero in that case,\n  # so the mean stays unchanged.\n  divisor = torch.where(\n    current_weight_sum == 0, torch.ones_like(current_weight_sum), current_weight_sum\n  )\n  current_mean = current_mean + torch.sum((weight / divisor) * (value - current_mean))\n  return current_mean, current_weight_sum\n\n\ndef stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:\n  \"\"\"\n  Merge the state from multiple workers.\n  Args:\n    state: A 2D tensor with one row per worker; column 0 holds that worker's mean and\n      column 1 its sum of weights.\n\n  Returns: A 2-element tensor stacking the merged mean and the total sum of weights.\n\n  \"\"\"\n  mean, weight_sum = update_mean(\n    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),\n    current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),\n    value=state[:, 0],\n    weight=state[:, 1],\n  )\n  return torch.stack([mean, weight_sum])\n\n\nclass StableMean(torchmetrics.Metric):\n  \"\"\"\n  This implements a numerically stable mean metric computation using the Welford algorithm, see\n  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.\n  For example when using float32, the algorithm will give a valid output even if the \"sum\" is larger\n  than the maximum float32 as long as the mean is within the limit of float32.\n  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.\n  \"\"\"\n\n  def __init__(self, **kwargs):\n    \"\"\"\n    Args:\n      **kwargs: Additional parameters supported by all torchmetrics.Metric.\n    \"\"\"\n    super().__init__(**kwargs)\n    # State layout: [accumulated mean, accumulated sum of weights].\n    self.add_state(\n      \"mean_and_weight_sum\",\n      default=torch.zeros(2),\n      dist_reduce_fx=stable_mean_dist_reduce_fn,\n    )\n\n  def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:\n    \"\"\"\n    Update the current mean.\n    Args:\n      value: Value to update the mean with.\n      weight: weight to use. Shape should be broadcastable to that of value.\n    \"\"\"\n    mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]\n\n    if not isinstance(weight, torch.Tensor):\n      weight = torch.as_tensor(weight, dtype=value.dtype, device=value.device)\n\n    self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] = update_mean(\n      mean, weight_sum, value, weight\n    )\n\n  def compute(self) -> torch.Tensor:\n    \"\"\"\n    Compute and return the accumulated mean.\n    \"\"\"\n    return self.mean_and_weight_sum[0]\n"
  },
  {
    "path": "metrics/auroc.py",
    "content": "\"\"\"\nAUROC metrics.\n\"\"\"\nfrom typing import Union\n\nfrom tml.ml_logging.torch_logging import logging\n\nimport torch\nimport torchmetrics\nfrom torchmetrics.utilities.data import dim_zero_cat\n\n\ndef _compute_helper(\n  predictions: torch.Tensor,\n  target: torch.Tensor,\n  weights: torch.Tensor,\n  max_positive_negative_weighted_sum: torch.Tensor,\n  min_positive_negative_weighted_sum: torch.Tensor,\n  equal_predictions_as_incorrect: bool,\n) -> torch.Tensor:\n  \"\"\"\n  Compute AUROC.\n  Args:\n    predictions: The predictions probabilities.\n    target: The binary target (1.0 for positive, 0.0 for negative).\n    weights: The sample weights to assign to each sample in the batch.\n    max_positive_negative_weighted_sum: The larger of the weighted sums of the positive and the\n      negative labels.\n    min_positive_negative_weighted_sum: The smaller of the weighted sums of the positive and the\n      negative labels. The numerator is divided by both sums, i.e. by the total weight of all\n      (positive, negative) pairs.\n    equal_predictions_as_incorrect: For positive & negative labels having identical scores,\n     we assume that they are correct predictions (i.e. weight = 1) when this is False. Otherwise,\n     we assume that they are incorrect predictions (i.e. weight = 0).\n  \"\"\"\n  dim = 0\n\n  # Sort predictions based on key (score, true_label). The order is ascending for score.\n  # For true_label, order is descending if equal_predictions_as_incorrect is True (tied positives\n  # come before tied negatives, so ties are not counted); otherwise it is ascending (tied\n  # negatives come first, so ties are counted).\n  target_order = torch.argsort(target, dim=dim, descending=equal_predictions_as_incorrect)\n  score_order = torch.sort(torch.gather(predictions, dim, target_order), stable=True, dim=dim)[1]\n  score_order = torch.gather(target_order, dim, score_order)\n  sorted_target = torch.gather(target, dim, score_order)\n  sorted_weights = torch.gather(weights, dim, score_order)\n\n  # Weighted number of negatives at or before each position of the sorted order.\n  negatives_from_left = torch.cumsum((1.0 - sorted_target) * sorted_weights, 0)\n\n  numerator = torch.sum(\n    sorted_weights * (sorted_target * negatives_from_left / max_positive_negative_weighted_sum)\n  )\n\n  return numerator / min_positive_negative_weighted_sum\n\n\nclass AUROCWithMWU(torchmetrics.Metric):\n  \"\"\"\n  AUROC using Mann-Whitney U-test.\n  See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.\n\n  This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return\n  the correct AUROC even if the predicted probabilities are all close to 0.\n  Currently only supports binary classification.\n  \"\"\"\n\n  def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):\n    \"\"\"\n\n    Args:\n      label_threshold: Labels strictly above this threshold are considered positive labels,\n                       otherwise, they are considered negative.\n      raise_missing_class: If True, an error will be raised if negative or positive class is\n        missing. Otherwise, we will simply log a warning.\n      **kwargs: Additional parameters supported by all torchmetrics.Metric.\n    \"\"\"\n    super().__init__(**kwargs)\n    self.add_state(\"predictions\", default=[], dist_reduce_fx=\"cat\")\n    self.add_state(\"target\", default=[], dist_reduce_fx=\"cat\")\n    self.add_state(\"weights\", default=[], dist_reduce_fx=\"cat\")\n\n    self.label_threshold = label_threshold\n    self.raise_missing_class = raise_missing_class\n\n  def update(\n    self,\n    predictions: torch.Tensor,\n    target: torch.Tensor,\n    weight: Union[float, torch.Tensor] = 1.0,\n  ) -> None:\n    \"\"\"\n    Update the current auroc.\n    Args:\n      predictions: Predicted values, 1D Tensor or 2D Tensor of shape batch_size x 1.\n      target: Ground truth. Must have same shape as predictions.\n      weight: The weight to use for the predicted values. Shape should be\n      broadcastable to that of predictions.\n    \"\"\"\n    self.predictions.append(predictions)\n    self.target.append(target)\n    if not isinstance(weight, torch.Tensor):\n      weight = torch.as_tensor(weight, dtype=predictions.dtype, device=target.device)\n    self.weights.append(torch.broadcast_to(weight, predictions.size()))\n\n  def compute(self) -> torch.Tensor:\n    \"\"\"\n    Compute and return the accumulated AUROC.\n    \"\"\"\n    weights = dim_zero_cat(self.weights)\n    predictions = dim_zero_cat(self.predictions)\n    target = dim_zero_cat(self.target).type_as(predictions)\n\n    negative_mask = target <= self.label_threshold\n    positive_mask = torch.logical_not(negative_mask)\n\n    if not negative_mask.any():\n      msg = \"Negative class missing. AUROC returned will be meaningless.\"\n      if self.raise_missing_class:\n        raise ValueError(msg)\n      else:\n        logging.warning(msg)\n    if not positive_mask.any():\n      msg = \"Positive class missing. AUROC returned will be meaningless.\"\n      if self.raise_missing_class:\n        raise ValueError(msg)\n      else:\n        logging.warning(msg)\n\n    # Binarize with the same threshold as the masks above. `_compute_helper` uses the target\n    # values directly, so labels other than 0/1 would otherwise disagree with the\n    # positive/negative weight sums used to normalize the result.\n    binary_target = positive_mask.type_as(predictions)\n\n    weighted_actual_negative_sum = torch.sum(\n      torch.where(negative_mask, weights, torch.zeros_like(weights))\n    )\n\n    weighted_actual_positive_sum = torch.sum(\n      torch.where(positive_mask, weights, torch.zeros_like(weights))\n    )\n\n    max_positive_negative_weighted_sum = torch.max(\n      weighted_actual_negative_sum, weighted_actual_positive_sum\n    )\n\n    min_positive_negative_weighted_sum = torch.min(\n      weighted_actual_negative_sum, weighted_actual_positive_sum\n    )\n\n    # Compute auroc with the weight set to 1 when positive & negative have identical scores.\n    auroc_le = _compute_helper(\n      target=binary_target,\n      weights=weights,\n      predictions=predictions,\n      min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,\n      max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,\n      equal_predictions_as_incorrect=False,\n    )\n\n    # Compute auroc with the weight set to 0 when positive & negative have identical scores.\n    auroc_lt = _compute_helper(\n      target=binary_target,\n      weights=weights,\n      predictions=predictions,\n      min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,\n      max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,\n      equal_predictions_as_incorrect=True,\n    )\n\n    # Compute auroc with the weight set to 1/2 when positive & negative have identical scores.\n    return auroc_le - (auroc_le - auroc_lt) / 2.0\n"
  },
  {
    "path": "metrics/rce.py",
    "content": "\"\"\"\nContains RCE metrics.\n\"\"\"\nimport copy\nfrom functools import partial\nfrom typing import Union\n\nfrom tml.metrics import aggregation\n\nimport torch\nimport torchmetrics\n\n\ndef _smooth(\n  value: torch.Tensor, label_smoothing: Union[float, torch.Tensor]\n) -> Union[float, torch.Tensor]:\n  \"\"\"\n  Smooth given values.\n  Args:\n    value: Value to smooth.\n    label_smoothing: smoothing constant.\n  Returns: Smoothed values.\n  \"\"\"\n  return value * (1.0 - label_smoothing) + 0.5 * label_smoothing\n\n\ndef _binary_cross_entropy_with_clipping(\n  predictions: torch.Tensor,\n  target: torch.Tensor,\n  epsilon: Union[float, torch.Tensor],\n  reduction: str = \"none\",\n) -> torch.Tensor:\n  \"\"\"\n  Clip Predictions and apply binary cross entropy.\n  This is done to match the implementation in keras at\n  https://github.com/keras-team/keras/blob/r2.9/keras/backend.py#L5294-L5300\n  Args:\n    predictions: Predicted probabilities.\n    target: Ground truth.\n    epsilon: Epsilon fuzz factor used to clip the predictions.\n    reduction: The reduction method to use. \"mean\" returns the average; any other value returns\n      the element-wise losses.\n\n  Returns: Binary cross entropy on the clipped predictions.\n\n  \"\"\"\n  predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)\n  bce = -target * torch.log(predictions + epsilon)\n  bce -= (1.0 - target) * torch.log(1.0 - predictions + epsilon)\n  if reduction == \"mean\":\n    return torch.mean(bce)\n  return bce\n\n\nclass RCE(torchmetrics.Metric):\n  \"\"\"\n  Compute the relative cross entropy (`RCE <http://go/rce>`_).\n\n  RCE is a metric used for models predicting probability of success (p), i.e. pCTR.\n  RCE represents the binary `cross entropy <https://en.wikipedia.org/wiki/Cross_entropy>` of\n  the model compared to a reference straw man model.\n\n  Binary cross entropy is defined as:\n\n  y = label; p = prediction;\n  binary cross entropy(example) = - y * log(p) - (1-y) * log(1-p)\n\n  Where y in {0, 1}\n\n  Cross entropy of a model is defined as:\n\n  CE(model) = average(binary cross entropy(example))\n\n  Over all the examples we aggregate on.\n\n  The straw man model is quite simple, it is a constant predictor, always predicting the average\n  over the labels.\n\n  RCE of a model is defined as:\n\n  RCE(model) = 100 * (CE(reference model) - CE(model)) / CE(reference model)\n\n  .. note:: Maximizing the likelihood is the same as minimizing the cross entropy or maximizing\n            the RCE. Since cross entropy is the average minus likelihood for the binary case.\n\n  .. note:: Binary cross entropy of an example is non negative, and equal to the\n            `KL divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence\n            #Properties>`\n            since p is constant, and its entropy is equal to zero.\n\n  .. note:: 0% RCE means as good as the straw man model.\n            100% means always predicts exactly the label. Namely, cross entropy of the model is\n                always zero. In practice 100% is impossible to achieve due to clipping.\n            Negative RCE means that the model is doing worse than the straw man.\n            This usually means an un-calibrated model, namely, the average prediction\n            is \"far\" from the average label. Examining NRCE might help identify if that is\n            the case.\n\n  .. note:: RCE is not a \"ratio\" in the statistical\n            `level of measurement sense <https://en.wikipedia.org/wiki/Level_of_measurement>`.\n            The higher the model's RCE is the harder it is to improve it by an extra point.\n\n            For example:\n            Let CE(model) = 0.5 CE(reference model), then the RCE(model) = 50.\n            Now take a \"twice as good\" model:\n            Let CE(better model) = 0.5 CE(model) = 0.25 CE(reference model),\n            then the RCE(better model) = 75 and not 100.\n\n  .. note:: In order to keep the log function stable, typically p is limited to\n            lie in [CLAMP_EPSILON, 1-CLAMP_EPSILON],\n            where CLAMP_EPSILON is some small constant like: 1e-7.\n            Old implementation used 1e-5 clipping by default, current uses\n            tf.keras.backend.epsilon()\n            whose default is 1e-7.\n\n  .. note:: Since the reference model prediction is constant (probability),\n            CE(reference model) = H(average(label))\n\n            Where H is the standard\n            `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>` function.\n\n  .. note:: Must have at least 1 positive and 1 negative sample accumulated,\n            or RCE will come out as NaN.\n  \"\"\"\n\n  def __init__(\n    self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs\n  ):\n    \"\"\"\n    Args:\n      from_logits: whether or not predictions are logits or probabilities.\n      label_smoothing: label smoothing constant.\n      epsilon: Epsilon fuzz factor used on the predictions probabilities when from_logits is False.\n      **kwargs: Additional parameters supported by all torchmetrics.Metric.\n    \"\"\"\n    super().__init__(**kwargs)\n    self.from_logits = from_logits\n    self.label_smoothing = label_smoothing\n    self.epsilon = epsilon\n    self.kwargs = kwargs\n\n    self.mean_label = aggregation.StableMean(**kwargs)\n    self.binary_cross_entropy = aggregation.StableMean(**kwargs)\n\n    if self.from_logits:\n      self.bce_loss_fn = torch.nn.functional.binary_cross_entropy_with_logits\n    else:\n      self.bce_loss_fn = partial(_binary_cross_entropy_with_clipping, epsilon=self.epsilon)\n\n    # Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.\n    self.batch_metric = copy.deepcopy(self)\n\n  def update(\n    self,\n    predictions: torch.Tensor,\n    target: torch.Tensor,\n    weight: Union[float, torch.Tensor] = 1.0,\n  ) -> None:\n    \"\"\"\n    Update the current rce.\n    Args:\n      predictions: Predicted values.\n      target: Ground truth. Should have same shape as predictions.\n      weight: The weight to use for the predicted values. Shape should be broadcastable to that of\n       predictions.\n    \"\"\"\n    target = _smooth(target, self.label_smoothing)\n    self.mean_label.update(target, weight)\n    self.binary_cross_entropy.update(\n      self.bce_loss_fn(predictions, target, reduction=\"none\"), weight\n    )\n\n  def compute(self) -> torch.Tensor:\n    \"\"\"\n    Compute and return the accumulated rce.\n    \"\"\"\n    baseline_mean = self.mean_label.compute()\n\n    baseline_ce = _binary_cross_entropy_with_clipping(\n      baseline_mean, baseline_mean, reduction=\"mean\", epsilon=self.epsilon\n    )\n\n    pred_ce = self.binary_cross_entropy.compute()\n\n    return (1.0 - (pred_ce / baseline_ce)) * 100\n\n  def reset(self):\n    \"\"\"\n    Reset the metric to its initial state.\n    \"\"\"\n    super().reset()\n    self.mean_label.reset()\n    self.binary_cross_entropy.reset()\n\n  def forward(self, *args, **kwargs):\n    \"\"\"\n    Serves the dual purpose of both computing the metric on the current batch of inputs but also\n        add the batch statistics to the overall accumulating metric state.\n    Input arguments are the exact same as corresponding ``update`` method.\n    The returned output is the exact same as the output of ``compute``.\n    \"\"\"\n    self.update(*args, **kwargs)\n    self.batch_metric.update(*args, **kwargs)\n    batch_result = self.batch_metric.compute()\n    self.batch_metric.reset()\n    return batch_result\n\n\nclass NRCE(RCE):\n  \"\"\"\n  Calculate the RCE of the normalized model.\n  Where the normalized model prediction average is normalized to the average label seen so far.\n  Namely, the normalized model prediction:\n\n  normalized model prediction(example) = (model prediction(example) * average(label)) /\n  average(model prediction)\n\n  Where the average is over all previously seen examples.\n\n  .. note:: average(normalized model prediction) = average(label)\n\n  .. note:: NRCE can be misleading since it is oblivious to mis-calibrations.\n            The common interpretation of NRCE is to measure how good your model could potentially\n            perform if it was well calibrated.\n\n  .. note:: A big gap between NRCE and RCE might indicate a badly calibrated model.\n\n  \"\"\"\n\n  def __init__(\n    self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs\n  ):\n    \"\"\"\n\n    Args:\n      from_logits: whether or not predictions are logits or probabilities.\n      label_smoothing: label smoothing constant.\n      epsilon: Epsilon fuzz factor used on the predictions probabilities when from_logits is False.\n               It is only used when computing the cross entropy but not when normalizing.\n      **kwargs: Additional parameters supported by all torchmetrics.Metric.\n    \"\"\"\n    # NRCE.update applies the sigmoid and label smoothing itself before normalizing, so the\n    # parent's loss function must take probabilities (from_logits=False).\n    super().__init__(from_logits=False, label_smoothing=0, epsilon=epsilon, **kwargs)\n    self.nrce_from_logits = from_logits\n    self.nrce_label_smoothing = label_smoothing\n    # Built with the same torchmetrics kwargs as `mean_label` and `binary_cross_entropy`.\n    self.mean_prediction = aggregation.StableMean(**kwargs)\n\n    # Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.\n    self.batch_metric = copy.deepcopy(self)\n\n  def update(\n    self,\n    predictions: torch.Tensor,\n    target: torch.Tensor,\n    weight: Union[float, torch.Tensor] = 1.0,\n  ) -> None:\n    \"\"\"\n    Update the current nrce.\n    Args:\n      predictions: Predicted values.\n      target: Ground truth. Should have same shape as predictions.\n      weight: The weight to use for the predicted values. Shape should be broadcastable to that of\n       predictions.\n    \"\"\"\n    predictions = torch.sigmoid(predictions) if self.nrce_from_logits else predictions\n\n    target = _smooth(target, self.nrce_label_smoothing)\n    self.mean_label.update(target, weight)\n\n    self.mean_prediction.update(predictions, weight)\n\n    normalizer = self.mean_label.compute() / self.mean_prediction.compute()\n\n    predictions = predictions * normalizer\n\n    self.binary_cross_entropy.update(\n      self.bce_loss_fn(predictions, target, reduction=\"none\"), weight\n    )\n\n  def reset(self):\n    \"\"\"\n    Reset the metric to its initial state.\n    \"\"\"\n    super().reset()\n    self.mean_prediction.reset()\n"
  },
  {
    "path": "ml_logging/__init__.py",
    "content": ""
  },
  {
    "path": "ml_logging/absl_logging.py",
    "content": "\"\"\"Sets up logging through absl for training usage.\n\n- Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate.\n\nUsage:\n    >>> from tml.ml_logging.absl_logging import logging\n    >>> logging.info(f\"Properly logged as INFO level in GCP Stackdriver.\")\n\n\"\"\"\nimport logging as py_logging\nimport sys\n\nfrom absl import logging as logging\n\n\ndef setup_absl_logging():\n  \"\"\"Make sure that absl logging pushes to stdout rather than stderr.\n\n  Also installs a `[module.function:line - LEVEL] message` formatter on the absl handler and sets\n  the verbosity to INFO. Called at import time of this module.\n  \"\"\"\n  logging.get_absl_handler().python_handler.stream = sys.stdout\n  formatter = py_logging.Formatter(\n    fmt=\"[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s\"\n  )\n  logging.get_absl_handler().setFormatter(formatter)\n  logging.set_verbosity(logging.INFO)\n\n\nsetup_absl_logging()\n"
  },
  {
    "path": "ml_logging/test_torch_logging.py",
    "content": "import unittest\n\nfrom tml.ml_logging.torch_logging import logging\n\n\nclass TestTorchLogging(unittest.TestCase):\n  def test_warn_once(self):\n    \"\"\"Consecutive identical `warning` calls are logged only once.\"\"\"\n    with self.assertLogs(level=\"INFO\") as captured_logs:\n      logging.info(\"first info\")\n      logging.warning(\"first warning\")\n      logging.warning(\"first warning\")\n      logging.info(\"second info\")\n\n    self.assertEqual(\n      captured_logs.output,\n      [\n        \"INFO:absl:first info\",\n        \"WARNING:absl:first warning\",\n        \"INFO:absl:second info\",\n      ],\n    )\n"
  },
  {
    "path": "ml_logging/torch_logging.py",
    "content": "\"\"\"Overrides absl logger to be rank-aware for distributed pytorch usage.\n\n    >>> # in-bazel import\n    >>> from twitter.ml.logging.torch_logging import logging\n    >>> # out-bazel import\n    >>> from ml.logging.torch_logging import logging\n    >>> logging.info(f\"This only prints on rank 0 if distributed, otherwise prints normally.\")\n    >>> logging.info(f\"This prints on all ranks if distributed, otherwise prints normally.\", rank=-1)\n\n\"\"\"\nimport functools\nfrom typing import Optional\n\nfrom tml.ml_logging.absl_logging import logging as logging\nfrom absl import logging as absl_logging\n\nimport torch.distributed as dist\n\n\ndef rank_specific(logger):\n  \"\"\"Ensures that we only override a given logger once.\"\"\"\n  if hasattr(logger, \"_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC\"):\n    return logger\n\n  def _if_rank(logger_method, limit: Optional[int] = None):\n    if limit:\n      # If we are limiting redundant logs, wrap logging call with a cache\n      # to not execute if already cached.\n      def _wrap(_call):\n        @functools.lru_cache(limit)\n        def _logger_method(*args, **kwargs):\n          _call(*args, **kwargs)\n\n        return _logger_method\n\n      logger_method = _wrap(logger_method)\n\n    def _inner(msg, *args, rank: int = 0, **kwargs):\n      if not dist.is_initialized():\n        logger_method(msg, *args, **kwargs)\n      elif dist.get_rank() == rank:\n        logger_method(msg, *args, **kwargs)\n      elif rank < 0:\n        logger_method(f\"Rank{dist.get_rank()}: {msg}\", *args, **kwargs)\n\n    # Register this stack frame with absl logging so that it doesn't trample logging lines.\n    absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)\n\n    return _inner\n\n  logger.fatal = _if_rank(logger.fatal)\n  logger.error = _if_rank(logger.error)\n  logger.warning = _if_rank(logger.warning, limit=1)\n  logger.info = _if_rank(logger.info)\n  logger.debug = _if_rank(logger.debug)\n 
 logger.exception = _if_rank(logger.exception)\n\n  logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True\n\n\nrank_specific(logging)\n"
  },
  {
    "path": "model.py",
    "content": "\"\"\"Wraps servable model in loss and RecapBatch passing to be trainable.\"\"\"\n# flake8: noqa\nfrom typing import Callable\n\nfrom tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]\n\nimport torch\nimport torch.distributed as dist\nfrom torchrec.distributed.model_parallel import DistributedModelParallel\n\n\nclass ModelAndLoss(torch.nn.Module):\n  # Reconsider our approach at a later date: https://ppwwyyxx.com/blog/2022/Loss-Function-Separation/\n\n  def __init__(\n    self,\n    model,\n    loss_fn: Callable,\n  ) -> None:\n    \"\"\"\n    Args:\n      model: torch module to wrap.\n      loss_fn: Function for calculating loss, should accept logits and labels.\n    \"\"\"\n    super().__init__()\n    self.model = model\n    self.loss_fn = loss_fn\n\n  def forward(self, batch: \"RecapBatch\"):  # type: ignore[name-defined]\n    \"\"\"Runs model forward and calculates loss according to given loss_fn.\n\n    NOTE: The input signature here needs to be a Pipelineable object for\n    prefetching purposes during training using torchrec's pipeline.  However\n    the underlying model signature needs to be exportable to onnx, requiring\n    generic python types.  
see https://pytorch.org/docs/stable/onnx.html#types.\n\n    \"\"\"\n    outputs = self.model(batch)\n    losses = self.loss_fn(outputs[\"logits\"], batch.labels.float(), batch.weights.float())\n\n    outputs.update(\n      {\n        \"loss\": losses,\n        \"labels\": batch.labels,\n        \"weights\": batch.weights,\n      }\n    )\n\n    # Allow multiple losses.\n    return losses, outputs\n\n\ndef maybe_shard_model(\n  model,\n  device: torch.device,\n):\n  \"\"\"Set up and apply DistributedModelParallel to a model if running in a distributed environment.\n\n    If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies\n    DistributedModelParallel.\n\n  If not in a distributed environment, returns model directly.\n  \"\"\"\n  if dist.is_initialized():\n    logging.info(\"***** Wrapping in DistributedModelParallel *****\")\n    logging.info(f\"Model before wrapping: {model}\")\n    model = DistributedModelParallel(\n      module=model,\n      device=device,\n    )\n    logging.info(f\"Model after wrapping: {model}\")\n\n  return model\n\n\ndef log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:\n  \"\"\"Handy function to log the content of EBC embedding layer.\n     Only works for single GPU machines.\n\n  Args:\n      weight_name: name of tensor, as defined in model\n      table_name: name of the EBC table the weight is taken from\n      weight_tensor: embedding weight tensor\n  \"\"\"\n  logging.info(f\"{weight_name}, {table_name}\", rank=-1)\n  logging.info(f\"{weight_tensor.metadata()}\", rank=-1)\n  output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device(\"cuda:0\"))\n  weight_tensor.gather(out=output_tensor)\n  logging.info(f\"{output_tensor}\", rank=-1)\n"
  },
  {
    "path": "optimizers/__init__.py",
    "content": "from tml.optimizers.optimizer import compute_lr\n"
  },
  {
    "path": "optimizers/config.py",
    "content": "\"\"\"Optimization configurations for models.\"\"\"\n\nimport typing\n\nimport tml.core.config as base_config\n\nimport pydantic\n\n\nclass PiecewiseConstant(base_config.BaseConfig):\n  learning_rate_boundaries: typing.List[int] = pydantic.Field(None)\n  learning_rate_values: typing.List[float] = pydantic.Field(None)\n\n\nclass LinearRampToConstant(base_config.BaseConfig):\n  learning_rate: float\n  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(\n    description=\"Number of steps to ramp this up from zero.\"\n  )\n\n\nclass LinearRampToCosine(base_config.BaseConfig):\n  learning_rate: float\n  final_learning_rate: float\n  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(\n    description=\"Number of steps to ramp this up from zero.\"\n  )\n  final_num_steps: pydantic.PositiveInt = pydantic.Field(\n    description=\"Final number of steps where decay stops.\"\n  )\n\n\nclass LearningRate(base_config.BaseConfig):\n  constant: float = pydantic.Field(None, one_of=\"lr\")\n  linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of=\"lr\")\n  linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of=\"lr\")\n  piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of=\"lr\")\n\n\nclass OptimizerAlgorithmConfig(base_config.BaseConfig):\n  \"\"\"Base class for optimizer configurations.\"\"\"\n\n  lr: float\n  ...\n\n\nclass AdamConfig(OptimizerAlgorithmConfig):\n  # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam\n  lr: float\n  betas: typing.Tuple[float, float] = [0.9, 0.999]\n  eps: float = 1e-7  # Numerical stability in denominator.\n\n\nclass SgdConfig(OptimizerAlgorithmConfig):\n  lr: float\n  momentum: float = 0.0\n\n\nclass AdagradConfig(OptimizerAlgorithmConfig):\n  lr: float\n  eps: float = 0\n\n\nclass OptimizerConfig(base_config.BaseConfig):\n  learning_rate: LearningRate = pydantic.Field(\n    None,\n    description=\"Constant learning rates\",\n  )\n  
adam: AdamConfig = pydantic.Field(None, one_of=\"optimizer\")\n  sgd: SgdConfig = pydantic.Field(None, one_of=\"optimizer\")\n  adagrad: AdagradConfig = pydantic.Field(None, one_of=\"optimizer\")\n\n\ndef get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):\n  if optimizer_config.adam is not None:\n    return optimizer_config.adam\n  elif optimizer_config.sgd is not None:\n    return optimizer_config.sgd\n  elif optimizer_config.adagrad is not None:\n    return optimizer_config.adagrad\n  else:\n    raise ValueError(f\"No optimizer selected in optimizer_config, passed {optimizer_config}\")\n"
  },
  {
    "path": "optimizers/optimizer.py",
    "content": "from typing import Dict, Tuple\nimport math\nimport bisect\n\nfrom tml.optimizers.config import (\n  LearningRate,\n  OptimizerConfig,\n)\n\nimport torch\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import _LRScheduler\nfrom tml.ml_logging.torch_logging import logging\n\n\ndef compute_lr(lr_config, step):\n  \"\"\"Compute a learning rate.\"\"\"\n  if lr_config.constant is not None:\n    return lr_config.constant\n  elif lr_config.piecewise_constant is not None:\n    return lr_config.piecewise_constant.learning_rate_values[\n      bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)\n    ]\n  elif lr_config.linear_ramp_to_constant is not None:\n    slope = (\n      lr_config.linear_ramp_to_constant.learning_rate\n      / lr_config.linear_ramp_to_constant.num_ramp_steps\n    )\n    return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)\n  elif lr_config.linear_ramp_to_cosine is not None:\n    cfg = lr_config.linear_ramp_to_cosine\n    if step < cfg.num_ramp_steps:\n      slope = cfg.learning_rate / cfg.num_ramp_steps\n      return slope * step\n    elif step <= cfg.final_num_steps:\n      return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (\n        1.0\n        + math.cos(\n          math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)\n        )\n      )\n    else:\n      return cfg.final_learning_rate\n  else:\n    raise ValueError(f\"No option selected in lr_config, passed {lr_config}\")\n\n\nclass LRShim(_LRScheduler):\n  \"\"\"Shim to get learning rates into a LRScheduler.\n\n  This adheres to the torch.optim scheduler API and can be plugged anywhere that\n  e.g. 
exponential decay can be used.\n  \"\"\"\n\n  def __init__(\n    self,\n    optimizer,\n    lr_dict: Dict[str, LearningRate],\n    last_epoch=-1,\n    verbose=False,\n  ):\n    self.optimizer = optimizer\n    self.lr_dict = lr_dict\n    self.group_names = list(self.lr_dict.keys())\n\n    num_param_groups = sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)\n    if num_param_groups != len(lr_dict):\n      raise ValueError(\n        f\"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}.\"\n      )\n\n    super().__init__(optimizer, last_epoch, verbose)\n\n  def get_lr(self):\n    if not self._get_lr_called_within_step:\n      logging.warn(\n        \"To get the last learning rate computed by the scheduler, \" \"please use `get_last_lr()`.\",\n        UserWarning,\n      )\n    return self._get_closed_form_lr()\n\n  def _get_closed_form_lr(self):\n    return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]\n\n\ndef get_optimizer_class(optimizer_config: OptimizerConfig):\n  if optimizer_config.adam is not None:\n    return torch.optim.Adam\n  elif optimizer_config.sgd is not None:\n    return torch.optim.SGD\n  elif optimizer_config.adagrad is not None:\n    return torch.optim.Adagrad\n\n\ndef build_optimizer(\n  model: torch.nn.Module, optimizer_config: OptimizerConfig\n) -> Tuple[Optimizer, _LRScheduler]:\n  \"\"\"Builds an optimizer and LR scheduler from an OptimizerConfig.\n  Note: use this when you want the same optimizer and learning rate schedule for all your parameters.\n  \"\"\"\n  optimizer_class = get_optimizer_class(optimizer_config)\n  optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())\n  # We're passing everything in as one group here\n  scheduler = LRShim(optimizer, lr_dict={\"ALL_PARAMS\": optimizer_config.learning_rate})\n  return optimizer, scheduler\n"
  },
  {
    "path": "projects/__init__.py",
    "content": ""
  },
  {
    "path": "projects/home/recap/FEATURES.md",
    "content": "# Overview\nBelow is a description of the major feature groups which are input to the Twitter Heavy Ranking model.\n\nNote that not every request will have every feature available due to user settings or other constraints and there may be some differences in ranking \"For You\" based on different variables.\n\n## Aggregate Features\nTwitter's aggregate features comprise the bulk of Twitter's feature count and are generated by maintaining rolling aggregations of feature values within a specific scope within a specific time window. We compute aggregates over the long-term (50 days count) and short-term (\"real-time\" - under 3 days count and typically 30 mins count).\n\n<details>\n<summary><b>Show Details</b></summary>\nAggregate features are groups of multiple features generated as Cartesian crosses from a template and have the format\n<table>\n<tr>\n<td><b>Feature Group Name</b></td>\n<td><b>Engagement Scope</b></td>\n<td><b>Feature To Aggregate</b></td>\n<td><b>Aggregation Spec</b></td>\n</tr>\n</table>\n\n<ul>\n<li> The <b>Feature Group Name</b> is both the name of the aggregate feature and contains internally the aggregation scope, that is, what entities are aggregated over. \n<ul>\n<li> For example, <code>\"user_aggregate\"</code> aggregates over unique user_ids, and <code>\"user_author_aggregate\"</code> aggregates over all user-author pairs. It also determines what fields the feature is joined to when being used. In the case of <code>\"user_author_aggregate\"</code>, the feature is joined to data corresponding to the specific user and the specific author. \n<li> The raw feature group names are often verbose and are simplified in the below presentation.\n</ul>\n<li> <b>Engagement Scope</b> is the subset of tweets within the aggregation scope that will be aggregated over. Typically this is the name of an output engagement, like <code>recap.engagement.is_favorited</code>. 
In that case, we only aggregate over Tweets that were also Liked.\n<li> The <b>Feature To Aggregate</b> is the feature we are accumulating over. If this value is <code>any_feature</code>, we aggregate the Tweet count. For example, <code>user_aggregate_v2.pair.recap.engagement.is_favorited.any_feature.50.days.count</code> is the number of Liked records for every user over the last 50 days.\n<li> The <b>Aggregation Spec</b> specifies which aggregate to compute: which function, and over which time window.\n</ul>\n\nFor every Feature Group, we generate one feature for every possible combination of Engagement Scope, Feature To Aggregate, and Aggregation Spec. In particular, every row in the tables below generates one feature for every possible cross between its columns.\n\n<b>Example</b>:\nFor example, one such feature may be <code>user_aggregate_v2.pair.recap.engagement.is_favorited.engagement_features.in_network.replies.count.50.days.count</code>, which can be parsed into\n<table>\n<tr>\n<td><b>Feature Group Name</b></td>\n<td><b>Engagement Scope</b></td>\n<td><b>Feature To Aggregate</b></td>\n<td><b>Aggregation Spec</b></td>\n</tr>\n<tr>\n<td><code>user_aggregate_v2.pair</code></td>\n<td><code>recap.engagement.is_favorited</code></td>\n<td><code>engagement_features.in_network.replies.count</code></td>\n<td><code>50.days.count</code></td>\n</tr>\n</table>\n\nThis means that this feature aggregates\n<ol>\n<li> (Over every user),\n<li> (Over only tweets favorited by the user),\n<li> In-network replies sent out by this user,\n<li> (Counted over the last 50 days)\n</ol>\nThis feature is then made available as a feature for the particular user. \n\n</details>\n\nOur aggregate features are listed below:\n<details>\n<summary><b><code>author_aggregate</code></b></summary>\nThese features aggregate over the author (or original author) of a tweet. Some of the features use short windows (30 minutes or 1 day) and others a long window (50 days). 
The features track how many of an author's tweets were engaged with.\n<br>\n<table>\n<tr>\n<td>\n<code>\nauthor (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.enagagement.is_retweeted_without_quote <br>\ntimelines.engagement.is_clicked <br>\ntimelines.engagement.is_dont_like <br>\ntimelines.engagement.is_dwelled <br>\ntimelines.engagement.is_favorited <br>\ntimelines.engagement.is_followed <br>\ntimelines.engagement.is_open_linked <br>\ntimelines.engagement.is_photo_expanded <br>\ntimelines.engagement.is_profile_clicked <br>\ntimelines.engagement.is_quoted <br>\ntimelines.engagement.is_replied <br>\ntimelines.engagement.is_retweeted <br>\ntimelines.engagement.is_tweet_share_dm_clicked <br>\ntimelines.engagement.is_tweet_share_dm_sent <br>\ntimelines.engagement.is_video_playback_50 <br>\ntimelines.engagement.is_video_quality_viewed <br>\ntimelines.engagement.is_video_viewed <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\noriginal_author (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.enagagement.is_retweeted_without_quote <br>\ntimelines.engagement.is_clicked <br>\ntimelines.engagement.is_dont_like <br>\ntimelines.engagement.is_dwelled <br>\ntimelines.engagement.is_favorited <br>\ntimelines.engagement.is_followed <br>\ntimelines.engagement.is_open_linked <br>\ntimelines.engagement.is_photo_expanded <br>\ntimelines.engagement.is_profile_clicked <br>\ntimelines.engagement.is_quoted <br>\ntimelines.engagement.is_replied <br>\ntimelines.engagement.is_retweeted <br>\ntimelines.engagement.is_tweet_share_dm_clicked <br>\ntimelines.engagement.is_tweet_share_dm_sent <br>\ntimelines.engagement.is_video_playback_50 <br>\ntimelines.engagement.is_video_quality_viewed <br>\ntimelines.engagement.is_video_viewed <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n\n\n<tr>\n<td>\n<code>\noriginal_author 
(real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_share_menu_clicked <br>\ntimelines.engagement.is_shared <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count <br>\n1.days.count <br>\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\noriginal_author\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_replied_reply_favorited_by_author <br>\nrecap.engagement.is_replied_reply_impressed_by_author <br>\nrecap.engagement.is_replied_reply_replied_by_author <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n\n</table>\n</details>\n\n\n<details>\n<summary><b><code>author-topic_aggregate</code></b></summary>\nThese features aggregate over a specific tweet author and a specific topic. We only accumulate long (50 day) counts. \n<br>\n<table>\n<tr>\n<td>\n<code>\nauthor-topic\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n\n</table>\n</details>\n\n<details>\n<summary><b><code>list_aggregate</code></b></summary>\nThese features aggregate short term and long term engagement between a user and a list.\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_list\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature 
<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nlist (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_block_clicked <br>\ntimelines.engagement.is_dont_like <br>\ntimelines.engagement.is_dwelled <br>\ntimelines.engagement.is_favorited <br>\ntimelines.engagement.is_mute_clicked <br>\ntimelines.engagement.is_replied <br>\ntimelines.engagement.is_report_tweet_clicked <br>\ntimelines.engagement.is_retweeted <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n\n</table>\n</details>\n\n\n<details>\n<summary><b><code>user_aggregate</code></b></summary>\nThese features aggregate short term and long term engagement from a specific user. \n\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_v2\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\nengagement_features.in_network.favorites.count <br>\nengagement_features.in_network.replies.count <br>\nengagement_features.in_network.retweets.count <br>\nrealgraph.num_favorites.days_since_last <br>\nrealgraph.num_favorites.elapsed_days <br>\nrealgraph.num_favorites.ewma <br>\nrealgraph.num_favorites.non_zero_days <br>\nrealgraph.num_inspected_tweets.days_since_last <br>\nrealgraph.num_inspected_tweets.elapsed_days <br>\nrealgraph.num_inspected_tweets.ewma <br>\nrealgraph.num_inspected_tweets.non_zero_days <br>\nrealgraph.num_mentions.days_since_last <br>\nrealgraph.num_mentions.elapsed_days <br>\nrealgraph.num_mentions.ewma <br>\nrealgraph.num_mentions.non_zero_days <br>\nrealgraph.num_profile_views.days_since_last <br>\nrealgraph.num_profile_views.elapsed_days <br>\nrealgraph.num_profile_views.ewma <br>\nrealgraph.num_profile_views.non_zero_days <br>\nrealgraph.num_retweets.days_since_last <br>\nrealgraph.num_retweets.elapsed_days 
<br>\nrealgraph.num_retweets.ewma <br>\nrealgraph.num_retweets.non_zero_days <br>\nrealgraph.num_tweet_clicks.days_since_last <br>\nrealgraph.num_tweet_clicks.elapsed_days <br>\nrealgraph.num_tweet_clicks.ewma <br>\nrealgraph.num_tweet_clicks.non_zero_days <br>\nrealgraph.total_dwell_time.days_since_last <br>\nrealgraph.total_dwell_time.elapsed_days <br>\nrealgraph.total_dwell_time.ewma <br>\nrealgraph.total_dwell_time.non_zero_days <br>\nrecap.earlybird.fav_count_v2 <br>\nrecap.earlybird.reply_count_v2 <br>\nrecap.earlybird.retweet_count_v2 <br>\nrecap.searchfeature.blender_score <br>\nrecap.searchfeature.fav_count <br>\nrecap.searchfeature.reply_count <br>\nrecap.searchfeature.retweet_count <br>\nrecap.searchfeature.text_score <br>\nrecap.tweetfeature.bidirectional_fav_count <br>\nrecap.tweetfeature.bidirectional_reply_count <br>\nrecap.tweetfeature.bidirectional_retweet_count <br>\nrecap.tweetfeature.contains_media <br>\nrecap.tweetfeature.conversational_count <br>\nrecap.tweetfeature.embeds_impression_count <br>\nrecap.tweetfeature.embeds_url_count <br>\nrecap.tweetfeature.from_mutual_follow <br>\nrecap.tweetfeature.has_card <br>\nrecap.tweetfeature.has_image <br>\nrecap.tweetfeature.has_link <br>\nrecap.tweetfeature.has_multiple_media <br>\nrecap.tweetfeature.has_news <br>\nrecap.tweetfeature.has_periscope <br>\nrecap.tweetfeature.has_pro_video <br>\nrecap.tweetfeature.has_trend <br>\nrecap.tweetfeature.has_video <br>\nrecap.tweetfeature.has_vine <br>\nrecap.tweetfeature.has_visible_link <br>\nrecap.tweetfeature.is_business_score <br>\nrecap.tweetfeature.is_extended_reply <br>\nrecap.tweetfeature.is_reply <br>\nrecap.tweetfeature.is_retweet <br>\nrecap.tweetfeature.is_sensitive <br>\nrecap.tweetfeature.link_count <br>\nrecap.tweetfeature.link_language <br>\nrecap.tweetfeature.match_searcher_langs <br>\nrecap.tweetfeature.match_searcher_main_lang <br>\nrecap.tweetfeature.match_ui_lang <br>\nrecap.tweetfeature.mention_searcher 
<br>\nrecap.tweetfeature.num_hashtags <br>\nrecap.tweetfeature.num_mentions <br>\nrecap.tweetfeature.reply_other <br>\nrecap.tweetfeature.reply_searcher <br>\nrecap.tweetfeature.retweet_other <br>\nrecap.tweetfeature.retweet_searcher <br>\nrecap.tweetfeature.tweet_count_from_user_in_snapshot <br>\nrecap.tweetfeature.unidirectiona_fav_count <br>\nrecap.tweetfeature.unidirectional_reply_count <br>\nrecap.tweetfeature.unidirectional_retweet_count <br>\nrecap.tweetfeature.user_rep <br>\nrecap.tweetfeature.video_view_count <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count<br>\n50.days.sum<br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_v5\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked<br>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_open_linked<br>\nrecap.engagement.is_photo_expanded<br>\nrecap.engagement.is_profile_clicked<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\nrecap.engagement.is_video_playback_50<br>\n</code>\n</td>\n<td>\n<code>\nany_feature 
<br>\ntime_features.earlybird.last_favorite_since_creation_hrs<br>\ntime_features.earlybird.last_quote_since_creation_hrs<br>\ntime_features.earlybird.last_reply_since_creation_hrs<br>\ntime_features.earlybird.last_retweet_since_creation_hrs<br>\ntime_features.earlybird.time_since_last_favorite<br>\ntime_features.earlybird.time_since_last_quote<br>\ntime_features.earlybird.time_since_last_reply<br>\ntime_features.earlybird.time_since_last_retweet<br>\ntimelines.earlybird.decayed_favorite_count<br>\ntimelines.earlybird.decayed_quote_count<br>\ntimelines.earlybird.decayed_reply_count<br>\ntimelines.earlybird.decayed_retweet_count<br>\ntimelines.earlybird.embeds_impression_count_v2<br>\ntimelines.earlybird.embeds_url_count_v2<br>\ntimelines.earlybird.fake_favorite_count<br>\ntimelines.earlybird.fake_quote_count<br>\ntimelines.earlybird.fake_reply_count<br>\ntimelines.earlybird.fake_retweet_count<br>\ntimelines.earlybird.quote_count<br>\ntimelines.earlybird.visible_token_ratio<br>\ntimelines.earlybird.weighted_fav_count<br>\ntimelines.earlybird.weighted_quote_count<br>\ntimelines.earlybird.weighted_reply_count<br>\ntimelines.earlybird.weighted_retweet_count<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count<br>\n50.days.sum<br>\n50.days.sumsq<br>\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser_v6\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_replied_reply_favorited_by_author<br>\nrecap.engagement.is_replied_reply_impressed_by_author<br>\nrecap.engagement.is_replied_reply_replied_by_author<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser (twitter_wide)\n</code>\n</td>\n<td>\n<code>\nany_label<br>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\n</code>\n</td>\n<td>\n<code>\nany_feature 
<br>\nrecap.tweetfeature.contains_media<br>\nrecap.tweetfeature.has_card<br>\nrecap.tweetfeature.has_hashtag<br>\nrecap.tweetfeature.has_link<br>\nrecap.tweetfeature.has_mention<br>\nrecap.tweetfeature.is_reply<br>\ntimelines.earlybird.has_quote<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n\n\n<tr>\n<td>\n<code>\nuser (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.enagagement.is_retweeted_without_quote<br>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_followed<br>\ntimelines.engagement.is_open_linked<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_quoted<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_tweet_share_dm_clicked<br>\ntimelines.engagement.is_tweet_share_dm_sent<br>\ntimelines.engagement.is_video_playback_50<br>\ntimelines.engagement.is_video_quality_viewed<br>\ntimelines.engagement.is_video_viewed<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\nclient_log_event.tweet.has_consumer_video<br>\nclient_log_event.tweet.photo_count<br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser 
(48h_real_time_v5)\n</code>\n</td>\n<td>\n<code>\ntimelines.enagagement.is_retweeted_without_quote<br>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_followed<br>\ntimelines.engagement.is_open_linked<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_quoted<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_tweet_share_dm_clicked<br>\ntimelines.engagement.is_tweet_share_dm_sent<br>\ntimelines.engagement.is_video_playback_50<br>\ntimelines.engagement.is_video_quality_viewed<br>\ntimelines.engagement.is_video_viewed<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\nclient_log_event.tweet.has_consumer_video<br>\nclient_log_event.tweet.photo_count<br>\n</code>\n</td>\n<td>\n<code>\n2.days.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser (72h_real_time_v6)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_block_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_mute_clicked<br>\ntimelines.engagement.is_report_tweet_clicked<br>\n</code>\n</td>\n<td>\n<code>\ntimelines.author.user_state.is_user_heavy_non_tweeter<br>\ntimelines.author.user_state.is_user_heavy_tweeter<br>\ntimelines.author.user_state.is_user_light<br>\ntimelines.author.user_state.is_user_medium_non_tweeter<br>\ntimelines.author.user_state.is_user_medium_tweeter<br>\ntimelines.author.user_state.is_user_new<br>\n</code>\n</td>\n<td>\n<code>\n3.days.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser (profile_real_time_v6)\n</code>\n</td>\n<td>\n<code>\nprofile.engagement.is_clicked<br>\nprofile.engagement.is_dwelled<br>\nprofile.engagement.is_favorited<br>\nprofile.engagement.is_replied<br>\nprofile.engagement.is_retweeted<br>\n</code>\n</td>\n<td>\n<code>\nany_feature 
<br>\nclient_log_event.tweet.has_consumer_video<br>\nclient_log_event.tweet.photo_count<br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_share_menu_clicked<br>\ntimelines.engagement.is_shared  <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\nclient_log_event.tweet.has_consumer_video<br>\nclient_log_event.tweet.photo_count<br>\n</code>\n</td>\n<td>\n<code>\n1.days.count<br>\n30.minutes.count<br>\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_fullscreen_video_dwelled<br>\ntimelines.engagement.is_fullscreen_video_dwelled_10_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_20_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_30_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_5_sec<br>\ntimelines.engagement.is_profile_dwelled<br>\ntimelines.engagement.is_profile_dwelled_10_sec<br>\ntimelines.engagement.is_profile_dwelled_20_sec<br>\ntimelines.engagement.is_profile_dwelled_30_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled<br>\ntimelines.engagement.is_tweet_detail_dwelled_15_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_25_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_30_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_8_sec<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n1.days.count<br>\n30.minutes.count<br>\n</code>\n</td>\n</tr>\n\n</table>\n</details>\n\n<details>\n<summary><b><code>user_author_aggregate</code></b></summary>\nThese features aggregate over user-author 
pairs.\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_author_v2\n</code>\n</td>\n<td>\n<code>\nany_label<br>\nrecap.engagement.is_clicked<br>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_open_linked<br>\nrecap.engagement.is_photo_expanded<br>\nrecap.engagement.is_profile_clicked<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\nrecap.engagement.is_video_playback_50<br>\n</code>\n</td>\n<td>\n<code>\nengagement_features.in_network.favorites.count<br>\nengagement_features.in_network.replies.count<br>\nengagement_features.in_network.retweets.count<br>\nrecap.earlybird.fav_count_v2<br>\nrecap.earlybird.reply_count_v2<br>\nrecap.earlybird.retweet_count_v2<br>\nrecap.searchfeature.blender_score<br>\nrecap.searchfeature.fav_count<br>\nrecap.searchfeature.reply_count<br>\nrecap.searchfeature.retweet_count<br>\nrecap.searchfeature.text_score<br>\nrecap.tweetfeature.embeds_impression_count<br>\nrecap.tweetfeature.embeds_url_count<br>\nrecap.tweetfeature.has_card<br>\nrecap.tweetfeature.has_image<br>\nrecap.tweetfeature.has_link<br>\nrecap.tweetfeature.has_multiple_media<br>\nrecap.tweetfeature.has_news<br>\nrecap.tweetfeature.has_periscope<br>\nrecap.tweetfeature.has_pro_video<br>\nrecap.tweetfeature.has_trend<br>\nrecap.tweetfeature.has_video<br>\nrecap.tweetfeature.has_vine<br>\nrecap.tweetfeature.has_visible_link<br>\nrecap.tweetfeature.is_reply<br>\nrecap.tweetfeature.is_retweet<br>\nrecap.tweetfeature.num_mentions<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count<br>\n50.days.sum<br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_author_v5\n</code>\n</td>\n<td>\n<code>\nany_label<br>\nrecap.engagement.is_clicked<br>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_open_linked<br>\nrecap.engagement.is_photo_expanded<br>\nrecap.engagement.is_profile_clicked<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\nrecap.engagement.is_video_playback_50<br>\n</code>\n</td>\n<td>\n<code>\nany_feature<br>\ntimelines.earlybird.
has_quote<br>\ntimelines.earlybird.label_abusive_flag<br>\ntimelines.earlybird.label_abusive_hi_rcl_flag<br>\ntimelines.earlybird.label_dup_content_flag<br>\ntimelines.earlybird.label_nsfw_hi_prc_flag<br>\ntimelines.earlybird.label_nsfw_hi_rcl_flag<br>\ntimelines.earlybird.label_spam_flag<br>\ntimelines.earlybird.label_spam_hi_rcl_flag<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_author (tweetsource_v1 - <br>\nThese features are sourced from a different underlying dataset)\n</code>\n</td>\n<td>\n<code>\nany_label<br>\nrecap.engagement.is_clicked<br>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_open_linked<br>\nrecap.engagement.is_photo_expanded<br>\nrecap.engagement.is_profile_clicked<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\nrecap.engagement.is_video_playback_50<br>\n</code>\n</td>\n<td>\n<code>\nany_feature<br>\ntweetsource.tweet.media.num_tags<br>\ntweetsource.tweet.media.video_duration<br>\ntweetsource.tweet.text.has_question<br>\ntweetsource.tweet.text.length<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count<br>\n50.days.sum<br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_author (twitter_wide - <br>\nThese features are sourced from a different underlying dataset)\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_favorited<br>\nrecap.engagement.is_replied<br>\nrecap.engagement.is_retweeted<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\nrecap.tweetfeature.contains_media<br>\nrecap.tweetfeature.has_card<br>\nrecap.tweetfeature.has_hashtag<br>\nrecap.tweetfeature.has_link<br>\nrecap.tweetfeature.has_mention<br>\nrecap.tweetfeature.is_reply<br>\ntimelines.earlybird.has_quote<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count<br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_original_author 
(real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_shared<br>\n</code>\n</td>\n<td>\n<code>\nany_feature<br>\n</code>\n</td>\n<td>\n<code>\n1.days.count<br>\n30.minutes.count<br>\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\nuser_original_author\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_replied_reply_favorited_by_author<br>\nrecap.engagement.is_replied_reply_impressed_by_author<br>\nrecap.engagement.is_replied_reply_replied_by_author<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_author (real_time, shared)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_negative_feedback_union<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_share_menu_clicked<br>\ntimelines.engagement.is_video_playback_50\n</code>\n</td>\n<td>\n<code>\nany_feature\n</code>\n</td>\n<td>\n<code>\n1.days.count<br>\n30.minutes.count\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n\n<details>\n<summary><b><code>user_engager_aggregate</code></b></summary>\nThese features aggregate counts of the user's interactions with other engagers of Tweets that the user interacts with.\n\nFor example, the <code>user_engager.recap.engagement.is_favorited.any_feature.50.days.count.sparse_top1</code> feature can be parsed as follows: \n\nFor all Tweets that a user Likes, keep a running 50-day count of engagement events for every other user who has engaged with those Tweets; here, engagement is defined as a Like or a reply. This yields a list of engagement counts, one for each other user who engaged with the Tweets the user Liked, and the feature value is the top (largest) of these counts.  
\n\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_engager <br>\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count.sparse_mean <br>\n50.days.count.sparse_nonzero <br>\n50.days.count.sparse_sum <br>\n50.days.count.sparse_top1 <br>\n50.days.count.sparse_top2 <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>user_inferred_topic_aggregate</code></b></summary>\nThese features aggregate short term and long term engagement between a user and tweets from our internally predicted inferred topic (whether or not the tweet is actually tagged to that topic).\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_inferred_topic_v1\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count.sparse_mean <br>\n50.days.count.sparse_nonzero <br>\n50.days.count.sparse_sum <br>\n50.days.count.sparse_top1 <br>\n50.days.count.sparse_top2 <br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_inferred_topic_v2\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted 
<br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nengagement_features.in_network.favorites.count <br>\nengagement_features.in_network.retweets.count <br>\nrecap.searchfeature.fav_count <br>\nrecap.tweetfeature.contains_media <br>\nrecap.tweetfeature.has_card <br>\nrecap.tweetfeature.has_image <br>\nrecap.tweetfeature.has_link <br>\nrecap.tweetfeature.has_news <br>\nrecap.tweetfeature.has_trend <br>\nrecap.tweetfeature.has_video <br>\nrecap.tweetfeature.is_reply <br>\nrecap.tweetfeature.is_retweet <br>\nrecap.tweetfeature.is_sensitive <br>\nrecap.tweetfeature.match_searcher_langs <br>\nrecap.tweetfeature.match_searcher_main_lang <br>\nrecap.tweetfeature.match_ui_lang <br>\nrecap.tweetfeature.mention_searcher <br>\nrecap.tweetfeature.reply_other <br>\nrecap.tweetfeature.reply_searcher <br>\nrecap.tweetfeature.retweet_other <br>\nrecap.tweetfeature.retweet_searcher <br>\ntweetsource.tweet.media.aspect_ratio_den <br>\ntweetsource.tweet.text.num_caps <br>\ntweetsource.tweet.text.num_newlines <br>\ntweetsource.v2.tweet.media.has_description <br>\ntweetsource.v2.tweet.media.has_selected_preview_image <br>\ntweetsource.v2.tweet.media.has_title <br>\ntweetsource.v2.tweet.media.has_visit_site_call_to_action <br>\ntweetsource.v2.tweet.media.has_watch_now_call_to_action <br>\ntweetsource.v2.tweet.media.is_360 <br>\ntweetsource.v2.tweet.media.is_managed <br>\ntweetsource.v2.tweet.media.is_monetizable <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count.sparse_mean <br>\n50.days.count.sparse_nonzero <br>\n50.days.count.sparse_sum <br>\n50.days.count.sparse_top1 <br>\n50.days.count.sparse_top2 <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>user_media_annotation_aggregate</code></b></summary>\nThese features aggregate how often a user interacts with different types of media (photo, video, etc)\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_media_annotation\n(keyed by user and media 
type)\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature\n</code>\n</td>\n<td>\n<code>\n50.days.count.sparse_mean <br>\n50.days.count.sparse_nonzero <br>\n50.days.count.sparse_sum <br>\n50.days.count.sparse_top1 <br>\n50.days.count.sparse_top2 <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n<details>\n<summary><b><code>user_mention_aggregate</code></b></summary>\nThese features aggregate counts of user interactions with Tweets that mention other users.\n\nLet the original user who viewed a Tweet be <code>user1</code>, and let <code>user2, user3, ..., user_n</code> be the users mentioned in that Tweet. This feature group aggregates the interactions between <code>user1</code> and other Tweets that mention <code>user2, user3, ..., user_n</code>.\n\nHere <code>sparse_sum</code> means we sum the aggregate values over all mentioned users, <code>sparse_top1</code> means we take the max of the aggregate values for the mentioned users, <code>sparse_top2</code> means we take the second-highest of the aggregate values for the mentioned users, and so on.\n\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_mention <br>\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count.sparse_mean <br>\n50.days.count.sparse_nonzero 
<br>\n50.days.count.sparse_sum <br>\n50.days.count.sparse_top1 <br>\n50.days.count.sparse_top2 <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>user_request_context_aggregate</code></b></summary>\nThese features aggregate engagements over the request context, which is either the same day of week (dow) or hour of day (hour), to account for temporal effects.\n<br>\n<table>\n<tr>\n<td>\n<code>\ndow <br>\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded  <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count <br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nhour <br>\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded  <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>user_topic_aggregate</code></b></summary>\nThese features aggregate long term feature values between a user and tweets from a particular topic.\n<br>\n<table>\n<tr>\n<td>\n<code>\nuser_topic_v1\n</code>\n</td>\n<td>\n<code>\nany_label <br>\nrecap.engagement.is_clicked <br>\nrecap.engagement.is_favorited <br>\nrecap.engagement.is_open_linked <br>\nrecap.engagement.is_photo_expanded <br>\nrecap.engagement.is_profile_clicked <br>\nrecap.engagement.is_replied <br>\nrecap.engagement.is_retweeted <br>\nrecap.engagement.is_video_playback_50 
<br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nuser_topic_v2\n</code>\n</td>\n<td>\n<code>\nrecap.engagement.is_clicked  <br>\nrecap.engagement.is_favorited  <br>\nrecap.engagement.is_open_linked  <br>\nrecap.engagement.is_photo_expanded  <br>\nrecap.engagement.is_profile_clicked  <br>\nrecap.engagement.is_replied  <br>\nrecap.engagement.is_retweeted  <br>\nrecap.engagement.is_video_playback_50  <br>\n</code>\n</td>\n<td>\n<code>\nengagement_features.in_network.favorites.count  <br>\nengagement_features.in_network.retweets.count  <br>\nrecap.searchfeature.fav_count  <br>\nrecap.tweetfeature.contains_media  <br>\nrecap.tweetfeature.has_card  <br>\nrecap.tweetfeature.has_image  <br>\nrecap.tweetfeature.has_link  <br>\nrecap.tweetfeature.has_news  <br>\nrecap.tweetfeature.has_trend  <br>\nrecap.tweetfeature.has_video  <br>\nrecap.tweetfeature.is_reply  <br>\nrecap.tweetfeature.is_retweet  <br>\nrecap.tweetfeature.is_sensitive  <br>\nrecap.tweetfeature.match_searcher_langs  <br>\nrecap.tweetfeature.match_searcher_main_lang  <br>\nrecap.tweetfeature.match_ui_lang  <br>\nrecap.tweetfeature.mention_searcher  <br>\nrecap.tweetfeature.reply_other  <br>\nrecap.tweetfeature.reply_searcher  <br>\nrecap.tweetfeature.retweet_other  <br>\nrecap.tweetfeature.retweet_searcher  <br>\ntweetsource.tweet.media.aspect_ratio_den  <br>\ntweetsource.tweet.text.num_caps  <br>\ntweetsource.tweet.text.num_newlines  <br>\ntweetsource.v2.tweet.media.has_description  <br>\ntweetsource.v2.tweet.media.has_selected_preview_image  <br>\ntweetsource.v2.tweet.media.has_title  <br>\ntweetsource.v2.tweet.media.has_visit_site_call_to_action  <br>\ntweetsource.v2.tweet.media.has_watch_now_call_to_action  <br>\ntweetsource.v2.tweet.media.is_360  <br>\ntweetsource.v2.tweet.media.is_managed  <br>\ntweetsource.v2.tweet.media.is_monetizable  
<br>\n</code>\n</td>\n<td>\n<code>\n50.days.count\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>topic_aggregate</code></b></summary>\nThese features aggregate values for tweets that come from a particular topic.\n<br>\n<table>\n<tr>\n<td>\n<code>\ntopic (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.enagagement.is_retweeted_without_quote <br>\ntimelines.engagement.is_clicked <br>\ntimelines.engagement.is_dont_like <br>\ntimelines.engagement.is_dwelled <br>\ntimelines.engagement.is_favorited <br>\ntimelines.engagement.is_followed <br>\ntimelines.engagement.is_not_interested_in_topic <br>\ntimelines.engagement.is_open_linked <br>\ntimelines.engagement.is_photo_expanded <br>\ntimelines.engagement.is_profile_clicked <br>\ntimelines.engagement.is_quoted <br>\ntimelines.engagement.is_replied <br>\ntimelines.engagement.is_retweeted <br>\ntimelines.engagement.is_tweet_share_dm_clicked <br>\ntimelines.engagement.is_tweet_share_dm_sent <br>\ntimelines.engagement.is_video_playback_50 <br>\ntimelines.engagement.is_video_quality_viewed <br>\ntimelines.engagement.is_video_viewed <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\ntopic 
(24_hour_real_time)\n</code>\n</td>\n<td>\n<code>timelines.enagagement.is_retweeted_without_quote<br>\ntimelines.engagement.is_block_clicked<br>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_followed<br>\ntimelines.engagement.is_mute_clicked<br>\ntimelines.engagement.is_not_about_topic<br>\ntimelines.engagement.is_not_interested_in_topic<br>\ntimelines.engagement.is_not_recent<br>\ntimelines.engagement.is_not_relevant<br>\ntimelines.engagement.is_open_linked<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_quoted<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_report_tweet_clicked<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_see_fewer<br>\ntimelines.engagement.is_tweet_share_dm_clicked<br>\ntimelines.engagement.is_tweet_share_dm_sent<br>\ntimelines.engagement.is_unfollow_topic<br>\ntimelines.engagement.is_video_playback_50<br>\ntimelines.engagement.is_video_quality_viewed<br>\ntimelines.engagement.is_video_viewed\n</code></td>\n<td><code>any_feature</code></td>\n<td><code>1.days.count</code></td>\n</tr>\n<tr>\n<td>\n<code>\ntopic-country_code 
(real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_block_clicked<br>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_impressed<br>\ntimelines.engagement.is_mute_clicked<br>\ntimelines.engagement.is_not_about_topic<br>\ntimelines.engagement.is_not_interested_in_topic<br>\ntimelines.engagement.is_not_recent<br>\ntimelines.engagement.is_not_relevant<br>\ntimelines.engagement.is_open_linked<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_report_tweet_clicked<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_see_fewer<br>\ntimelines.engagement.is_share_menu_clicked<br>\ntimelines.engagement.is_shared<br>\ntimelines.engagement.is_unfollow_topic<br>\ntimelines.engagement.is_video_playback_50<br>\ntimelines.engagement.is_video_quality_viewed\n</code>\n</td>\n<td><code>any_feature</code></td>\n<td><code>3.days.count<br>30.minutes.count</code></td>\n</tr>\n<tr>\n<td>\n<code>\ntopic-share (real_time)\n</code>\n</td>\n<td>\n<code>\ntimelines.engagement.is_share_menu_clicked<br>\ntimelines.engagement.is_shared\n</code>\n</td>\n<td><code>any_feature</code></td>\n<td><code>1.days.count<br>30.minutes.count</code></td>\n</tr>\n</table>\n</details>\n\n<details>\n<summary><b><code>tweet_aggregate</code></b></summary>\nThese features aggregate values corresponding to a tweet.\n<br>\n<table>\n<tr>\n<td><code>tweet 
(real_time)</code></td>\n<td><code>\ntimelines.enagagement.is_retweeted_without_quote<br>\ntimelines.engagement.is_clicked<br>\ntimelines.engagement.is_dont_like<br>\ntimelines.engagement.is_dwelled<br>\ntimelines.engagement.is_favorited<br>\ntimelines.engagement.is_followed<br>\ntimelines.engagement.is_open_linked<br>\ntimelines.engagement.is_photo_expanded<br>\ntimelines.engagement.is_profile_clicked<br>\ntimelines.engagement.is_quoted<br>\ntimelines.engagement.is_replied<br>\ntimelines.engagement.is_retweeted<br>\ntimelines.engagement.is_tweet_share_dm_clicked<br>\ntimelines.engagement.is_tweet_share_dm_sent<br>\ntimelines.engagement.is_video_playback_50<br>\ntimelines.engagement.is_video_quality_viewed<br>\ntimelines.engagement.is_video_viewed\n</code>\n</td>\n<td><code>any_feature</code></td>\n<td>\n<code>\n30.minutes.count<br>\nDuration.Top.count\n</code>\n</td>\n</tr>\n<tr>\n<td><code>tweet_v2 (real_time)</code></td>\n<td>\n<code>\ntimelines.engagement.is_block_clicked <br>\ntimelines.engagement.is_mute_clicked <br>\ntimelines.engagement.is_report_tweet_clicked <br>\n</code>\n</td>\n<td>\n<code>\nany_feature <br>\n</code>\n</td>\n<td>\n<code>\n30.minutes.count <br>\nDuration.Top.count <br>\n</code>\n</td>\n</tr>\n<tr>\n<td><code>tweet (real_time dwell) 
</code></td>\n<td><code>timelines.engagement.is_fullscreen_video_dwelled<br>\ntimelines.engagement.is_fullscreen_video_dwelled_10_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_20_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_30_sec<br>\ntimelines.engagement.is_fullscreen_video_dwelled_5_sec<br>\ntimelines.engagement.is_profile_dwelled<br>\ntimelines.engagement.is_profile_dwelled_10_sec<br>\ntimelines.engagement.is_profile_dwelled_20_sec<br>\ntimelines.engagement.is_profile_dwelled_30_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled<br>\ntimelines.engagement.is_tweet_detail_dwelled_15_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_25_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_30_sec<br>\ntimelines.engagement.is_tweet_detail_dwelled_8_sec</code></td>\n<td>\n<code>any_feature\n</code>\n</td>\n<td><code>1.days.count<br>30.minutes.count</code></td>\n</tr>\n<tr>\n<td><code>tweet (real_time shared) </code></td>\n<td>\n<code>\ntimelines.engagement.is_share_menu_clicked<br>\ntimelines.engagement.is_shared\n</code>\n</td>\n<td><code>any_feature</code></td>\n<td><code>1.days.count<br>30.minutes.count</code></td>\n</tr>\n</table>\n</details>\n\n\n## Non Aggregate Features\nWe have a number of standalone features capturing information about the user, the tweet, the author, and the tweet context.\n\n<details>\n<summary><b><code>two_hop</code></b></summary>\n<br>\nThis feature group contains features about interactions which are \"two-hop\" between a user and the tweet author. 
An example of a two-hop interaction: if user 1 favorites a tweet by user 2, and user 2 favorites a tweet by user 3, there will be a positive value for the \"favorite.favorited_by\" two-hop feature between user 1 and user 3.\n\nThe feature group consists of all possible crosses of the below features.\n<table>\n<tr>\n<td>\n<code>\ntwo_hop\n</code>\n</td>\n<td>\n<code>\nfavorite  <br>\nfollowing  <br>\nmutual_follow <br>\n\n</code>\n</td>\n<td>\n<code>\nfavorited_by <br>\nfollowed_by <br>\nmentioned_by <br>\nretweeted_by <br>\n</code>\n</td>\n<td>\n<code>\nnormalized\n</code>\n</td>\n</tr>\n\n<tr>\n<td>\n<code>\ntwo_hop\n</code>\n</td>\n<td>\n<code>\n</code>\n</td>\n<td>\n<code>\nfavorited_by  <br>\nfollowed_by  <br>\nmentioned_by <br>\nretweeted_by\n</code>\n</td>\n<td>\n<code>\nright_degree\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n<details>\n\n<summary><b><code>realgraph</code></b></summary>\n<br>\nThis feature group contains features about interactions between the user and the Tweet author.\n\nThe feature group consists of all possible crosses of the below features.\n<table>\n<tr>\n<td>\n<code>\nrealgraph\n</code>\n</td>\n<td>\n<code>\ndst_id <br>\nsrc_id <br>\n</code>\n</td>\n<td>\n<code>\n</code>\n</td>\n\n</tr>\n<tr>\n<td>\n<code>\nrealgraph\n</code>\n</td>\n<td>\n<code>\nnum_address_book_email <br>\nnum_address_book_in_both <br>\nnum_address_book_mutual_edge_email <br>\nnum_address_book_mutual_edge_in_both <br>\nnum_address_book_mutual_edge_phone <br>\nnum_address_book_phone<br>\nnum_blocks<br>\nnum_direct_messages<br>\nnum_favorites<br>\nnum_follow<br>\nnum_inspected_tweets<br>\nnum_link_clicks<br>\nnum_mentions<br>\nnum_mutes<br>\nnum_mutual_follow<br>\nnum_photo_tags<br>\nnum_profile_views<br>\nnum_report_as_abuses<br>\nnum_report_as_spams<br>\nnum_retweets<br>\nnum_sms_follow<br>\nnum_tweet_clicks<br>\ntotal_dwell_time<br>\nweight\n</code>\n</td>\n<td>\n<code>\ndays_since_last <br>\ndays_since_last.sparse_avg <br>\ndays_since_last.sparse_max 
<br>\ndays_since_last.sparse_sum <br>\nelapsed_days <br>\nelapsed_days.sparse_avg <br>\nelapsed_days.sparse_max<br>\nelapsed_days.sparse_sum<br>\newma<br>\newma.sparse_avg<br>\newma.sparse_max<br>\newma.sparse_sum<br>\nis_missing<br>\nm2ForVariance.sparse_avg<br>\nm2ForVariance.sparse_max<br>\nm2ForVariance.sparse_sum<br>\nmean<br>\nmean.sparse_avg<br>\nmean.sparse_max<br>\nmean.sparse_sum<br>\nnon_zero_days<br>\nnon_zero_days.sparse_avg<br>\nnon_zero_days.sparse_max<br>\nnon_zero_days.sparse_sum<br>\nsparse_avg<br>\nsparse_max<br>\nsparse_sum<br>\nvariance\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n\n<details>\n<summary><b><code>authors.realgraph</code></b></summary>\nThis feature group contains features about interactions between the user and various other users including \n<ol>\n<li> the Tweet author\n<li>  any users mentioned in the Tweet\n<li>  in-network engagers with the Tweet\n<li>  upstream authors if the Tweet was part of a reply chain\n</ol>\nNote that all the above users are included in the interaction set, not just the Tweet author.\n\nThe feature group consists of all possible crosses of the below features.\n\n<br>\n<table>\n<tr>\n<td>\n<code>\nauthors.realgraph\n</code>\n</td>\n<td>\n<code>\nweight\n</code>\n</td>\n<td>\n<code>\n</code>\n</td>\n<td>\n<code>\nsparse_avg <br>\nsparse_max <br>\nsparse_sum <br>\n</code>\n</td>\n</tr>\n<tr>\n<td>\n<code>\nauthors.realgraph\n</code>\n</td>\n<td>\n<code>\nnum_address_book_email <br>\nnum_address_book_in_both <br>\nnum_address_book_mutual_edge_email <br>\nnum_address_book_mutual_edge_in_both <br>\nnum_address_book_phone <br>\nnum_blocks <br>\nnum_direct_messages <br>\nnum_favorites <br>\nnum_follow <br>\nnum_inspected_tweets <br>\nnum_link_clicks <br>\nnum_mentions <br>\nnum_mutes <br>\nnum_mutual_follow <br>\nnum_photo_tags <br>\nnum_profile_views <br>\nnum_report_as_abuses <br>\nnum_report_as_spams <br>\nnum_retweets <br>\nnum_sms_follow <br>\nnum_tweet_clicks <br>\ntotal_dwell_time 
<br>\n</code>\n</td>\n<td>\n<code>\ndays_since_last <br>\nelapsed_days <br>\newma <br>\nm2ForVariance <br>\nmean <br>\nnon_zero_days <br>\n</code>\n</td>\n<td>\n<code>\nsparse_avg <br>\nsparse_max <br>\nsparse_sum <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n<details>\n<summary><b><code>recap.tweetfeature, recap.searchfeature, etc</code></b></summary>\n<br>\nThis feature group contains features about the tweet, whether from the tweets service or the search service (\"Earlybird\"). It also contains features related to the user's device type.\n<table>\n<tr>\n<td>\n<code>\nrecap.earlybird.fav_count_v2 <br>\nrecap.earlybird.reply_count_v2 <br>\nrecap.earlybird.retweet_count_v2 <br>\nrecap.searchfeature.blender_score <br>\nrecap.searchfeature.fav_count <br>\nrecap.searchfeature.reply_count <br>\nrecap.searchfeature.retweet_count <br>\nrecap.searchfeature.text_score <br>\nrecap.source.type <br>\nrecap.tweetfeature.bidirectional_fav_count <br>\nrecap.tweetfeature.bidirectional_reply_count <br>\nrecap.tweetfeature.bidirectional_retweet_count <br>\nrecap.tweetfeature.contains_media <br>\nrecap.tweetfeature.conversational_count <br>\nrecap.tweetfeature.embeds_impression_count <br>\nrecap.tweetfeature.embeds_url_count <br>\nrecap.tweetfeature.from_inactive_user <br>\nrecap.tweetfeature.from_mutual_follow <br>\nrecap.tweetfeature.from_verified_account <br>\nrecap.tweetfeature.has_card <br>\nrecap.tweetfeature.has_consumer_video <br>\nrecap.tweetfeature.has_hashtag <br>\nrecap.tweetfeature.has_image <br>\nrecap.tweetfeature.has_link <br>\nrecap.tweetfeature.has_mention <br>\nrecap.tweetfeature.has_multiple_hashtag_or_trend <br>\nrecap.tweetfeature.has_multiple_media <br>\nrecap.tweetfeature.has_native_image <br>\nrecap.tweetfeature.has_native_video <br>\nrecap.tweetfeature.has_news <br>\nrecap.tweetfeature.has_periscope <br>\nrecap.tweetfeature.has_pro_video <br>\nrecap.tweetfeature.has_trend <br>\nrecap.tweetfeature.has_video <br>\nrecap.tweetfeature.has_vine 
<br>\nrecap.tweetfeature.has_visible_link <br>\nrecap.tweetfeature.is_author_bot <br>\nrecap.tweetfeature.is_author_new <br>\nrecap.tweetfeature.is_author_profile_egg <br>\nrecap.tweetfeature.is_author_spam <br>\nrecap.tweetfeature.is_business_score <br>\nrecap.tweetfeature.is_extended_reply <br>\nrecap.tweetfeature.is_offensive <br>\nrecap.tweetfeature.is_reply <br>\nrecap.tweetfeature.is_retweet <br>\nrecap.tweetfeature.is_sensitive <br>\nrecap.tweetfeature.language <br>\nrecap.tweetfeature.link_count <br>\nrecap.tweetfeature.link_language <br>\nrecap.tweetfeature.match_searcher_langs <br>\nrecap.tweetfeature.match_searcher_main_lang <br>\nrecap.tweetfeature.match_ui_lang <br>\nrecap.tweetfeature.mention_searcher <br>\nrecap.tweetfeature.num_hashtags <br>\nrecap.tweetfeature.num_mentions <br>\nrecap.tweetfeature.prev_user_tweet_enagagement <br>\nrecap.tweetfeature.reply_other <br>\nrecap.tweetfeature.reply_searcher <br>\nrecap.tweetfeature.retweet_other <br>\nrecap.tweetfeature.retweet_searcher <br>\nrecap.tweetfeature.signature <br>\nrecap.tweetfeature.tweet_count_from_user_in_snapshot <br>\nrecap.tweetfeature.unidirectiona_fav_count <br>\nrecap.tweetfeature.unidirectional_reply_count <br>\nrecap.tweetfeature.unidirectional_retweet_count <br>\nrecap.tweetfeature.user_rep <br>\nrecap.tweetfeature.video_view_count <br>\nrecap.user_agent.client_name <br>\nrecap.user_agent.client_source <br>\nrecap.user_agent.client_version <br>\nrecap.user_agent.client_version_code <br>\nrecap.user_agent.device <br>\nrecap.user_agent.manufacturer <br>\nrecap.user_agent.network_connection <br>\nrecap.user_agent.sdk_version <br>\nrecap.v2.tweetfeature.is_retweet_directed_at_user_in_first_degree <br>\nrecap.v2.tweetfeature.is_retweet_of_reply <br>\nrecap.v2.tweetfeature.is_retweeter_bot <br>\nrecap.v2.tweetfeature.is_retweeter_new <br>\nrecap.v2.tweetfeature.is_retweeter_nsfw <br>\nrecap.v2.tweetfeature.is_retweeter_profile_egg <br>\nrecap.v2.tweetfeature.is_retweeter_spam 
<br>\nrecap.v2.tweetfeature.retweet_of_mutual_follow <br>\nrecap.v2.tweetfeature.source_author_rep <br>\nrecap.v3.tweetfeature.probably_from_follow\n</code>\n</td>\n</tr>\n</table>\n</details>\n<details>\n<summary><b><code>tweetsource</code></b></summary>\n<br>\nThis feature group contains features about the tweet media as well as conversation-related features about the tweet.\n<table>\n<tr>\n<td>\n<code>\n<br> \ntweetsource.tweet.media.aspect_ratio_den <br> \ntweetsource.tweet.media.aspect_ratio_num <br> \ntweetsource.tweet.media.bit_rate <br> \ntweetsource.tweet.media.height_1 <br> \ntweetsource.tweet.media.height_2 <br> \ntweetsource.tweet.media.height_3 <br> \ntweetsource.tweet.media.height_4 <br> \ntweetsource.tweet.media.num_tags <br> \ntweetsource.tweet.media.resize_method_1 <br> \ntweetsource.tweet.media.resize_method_2 <br> \ntweetsource.tweet.media.resize_method_3 <br> \ntweetsource.tweet.media.resize_method_4 <br> \ntweetsource.tweet.media.video_duration <br> \ntweetsource.tweet.media.width_1 <br> \ntweetsource.tweet.media.width_2 <br> \ntweetsource.tweet.media.width_3 <br> \ntweetsource.tweet.media.width_4 <br> \ntweetsource.tweet.text.has_question <br> \ntweetsource.tweet.text.length <br> \ntweetsource.tweet.text.length_type <br> \ntweetsource.tweet.text.num_caps <br> \ntweetsource.tweet.text.num_newlines <br> \ntweetsource.tweet.text.num_whitespaces <br> \ntweetsource.v2.tweet.media.color_1_blue <br> \ntweetsource.v2.tweet.media.color_1_green <br> \ntweetsource.v2.tweet.media.color_1_percentage <br> \ntweetsource.v2.tweet.media.color_1_red <br> \ntweetsource.v2.tweet.media.face_areas <br> \ntweetsource.v2.tweet.media.has_app_install_call_to_action <br> \ntweetsource.v2.tweet.media.has_description <br> \ntweetsource.v2.tweet.media.has_selected_preview_image <br> \ntweetsource.v2.tweet.media.has_title <br> \ntweetsource.v2.tweet.media.has_visit_site_call_to_action <br> \ntweetsource.v2.tweet.media.has_watch_now_call_to_action <br> 
\ntweetsource.v2.tweet.media.is_360 <br> \ntweetsource.v2.tweet.media.is_embeddable <br> \ntweetsource.v2.tweet.media.is_managed <br> \ntweetsource.v2.tweet.media.is_monetizable <br> \ntweetsource.v2.tweet.media.num_color_pallette_items <br> \ntweetsource.v2.tweet.media.num_faces <br> \ntweetsource.v2.tweet.media.num_stickers <br> \ntweetsource.v2.tweet.media.view_count <br> \n</code>\n</td>\n</tr>\n</table>\n</details>\n\n<details>\n<summary><b><code>in_reply_to_tweet</code></b></summary>\n<br>\nIf the tweet was a reply, this feature group contains the features of the replied-to tweet.\n<table>\n<tr>\n<td>\n<code>\nin_reply_to_tweet.recap.earlybird.fav_count_v2 <br>\nin_reply_to_tweet.recap.earlybird.reply_count_v2 <br>\nin_reply_to_tweet.recap.earlybird.retweet_count_v2 <br>\nin_reply_to_tweet.recap.searchfeature.fav_count <br>\nin_reply_to_tweet.recap.searchfeature.reply_count <br>\nin_reply_to_tweet.recap.searchfeature.retweet_count <br>\nin_reply_to_tweet.recap.searchfeature.text_score <br>\nin_reply_to_tweet.recap.tweetfeature.bidirectional_fav_count <br>\nin_reply_to_tweet.recap.tweetfeature.bidirectional_reply_count <br>\nin_reply_to_tweet.recap.tweetfeature.bidirectional_retweet_count <br>\nin_reply_to_tweet.recap.tweetfeature.conversational_count <br>\nin_reply_to_tweet.recap.tweetfeature.from_mutual_follow <br>\nin_reply_to_tweet.recap.tweetfeature.from_verified_account <br>\nin_reply_to_tweet.recap.tweetfeature.has_hashtag <br>\nin_reply_to_tweet.recap.tweetfeature.has_image <br>\nin_reply_to_tweet.recap.tweetfeature.has_mention <br>\nin_reply_to_tweet.recap.tweetfeature.has_news <br>\nin_reply_to_tweet.recap.tweetfeature.has_video <br>\nin_reply_to_tweet.recap.tweetfeature.has_visible_link <br>\nin_reply_to_tweet.recap.tweetfeature.is_author_bot <br>\nin_reply_to_tweet.recap.tweetfeature.is_author_new <br>\nin_reply_to_tweet.recap.tweetfeature.is_author_nsfw <br>\nin_reply_to_tweet.recap.tweetfeature.is_author_spam 
<br>\nin_reply_to_tweet.recap.tweetfeature.is_offensive <br>\nin_reply_to_tweet.recap.tweetfeature.is_reply <br>\nin_reply_to_tweet.recap.tweetfeature.is_sensitive <br>\nin_reply_to_tweet.recap.tweetfeature.num_mentions <br>\nin_reply_to_tweet.recap.tweetfeature.prev_user_tweet_enagagement <br>\nin_reply_to_tweet.recap.tweetfeature.unidirectiona_fav_count <br>\nin_reply_to_tweet.recap.tweetfeature.unidirectional_reply_count <br>\nin_reply_to_tweet.recap.tweetfeature.unidirectional_retweet_count <br>\nin_reply_to_tweet.recap.tweetfeature.user_rep <br>\nin_reply_to_tweet.timelines.earlybird.decayed_favorite_count <br>\nin_reply_to_tweet.timelines.earlybird.decayed_quote_count <br>\nin_reply_to_tweet.timelines.earlybird.decayed_reply_count <br>\nin_reply_to_tweet.timelines.earlybird.decayed_retweet_count <br>\nin_reply_to_tweet.timelines.earlybird.has_quote <br>\nin_reply_to_tweet.timelines.earlybird.quote_count <br>\nin_reply_to_tweet.timelines.earlybird.weighted_fav_count <br>\nin_reply_to_tweet.timelines.earlybird.weighted_quote_count <br>\nin_reply_to_tweet.timelines.earlybird.weighted_reply_count <br>\nin_reply_to_tweet.timelines.earlybird.weighted_retweet_count <br>\nin_reply_to_tweet.timelines.earlybird_score <br>\nin_reply_to_tweet.tweetsource.tweet.media.aspect_ratio_den <br>\nin_reply_to_tweet.tweetsource.tweet.media.aspect_ratio_num <br>\nin_reply_to_tweet.tweetsource.tweet.media.height_1 <br>\nin_reply_to_tweet.tweetsource.tweet.media.height_2 <br>\nin_reply_to_tweet.tweetsource.tweet.media.video_duration <br>\nin_reply_to_tweet.tweetsource.tweet.text.has_question <br>\nin_reply_to_tweet.tweetsource.tweet.text.length <br>\nin_reply_to_tweet.tweetsource.tweet.text.num_caps <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n<details>\n<summary><b><code>timelines.earlybird</code></b></summary>\n<br>\nThis feature group passes on features used by the search and light ranking service (\"Earlybird\") to the Heavy Ranker. 
<br>\n<table>\n<tr>\n<td>\n<code>\ntimelines.earlybird.decayed_favorite_count <br>\ntimelines.earlybird.decayed_quote_count <br>\ntimelines.earlybird.decayed_reply_count <br>\ntimelines.earlybird.decayed_retweet_count <br>\ntimelines.earlybird.embeds_impression_count_v2 <br>\ntimelines.earlybird.embeds_url_count_v2 <br>\ntimelines.earlybird.fake_favorite_count <br>\ntimelines.earlybird.fake_quote_count <br>\ntimelines.earlybird.fake_reply_count <br>\ntimelines.earlybird.fake_retweet_count <br>\ntimelines.earlybird.has_quote <br>\ntimelines.earlybird.is_composer_source_camera <br>\ntimelines.earlybird.label_abusive_flag <br>\ntimelines.earlybird.label_abusive_hi_rcl_flag <br>\ntimelines.earlybird.label_dup_content_flag <br>\ntimelines.earlybird.label_nsfw_hi_prc_flag <br>\ntimelines.earlybird.label_nsfw_hi_rcl_flag <br>\ntimelines.earlybird.label_spam_flag <br>\ntimelines.earlybird.label_spam_hi_rcl_flag <br>\ntimelines.earlybird.periscope_exists <br>\ntimelines.earlybird.periscope_has_been_featured <br>\ntimelines.earlybird.periscope_is_currently_featured <br>\ntimelines.earlybird.periscope_is_from_quality_source <br>\ntimelines.earlybird.periscope_is_live <br>\ntimelines.earlybird.preported_tweet_score <br>\ntimelines.earlybird.quote_count <br>\ntimelines.earlybird.visible_token_ratio <br>\ntimelines.earlybird.weighted_fav_count <br>\ntimelines.earlybird.weighted_quote_count <br>\ntimelines.earlybird.weighted_reply_count <br>\ntimelines.earlybird.weighted_retweet_count <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n<details>\n<summary><b><code>realtime_interaction_graph</code></b></summary>\n<br>\nUser-author interaction features. Similar to RealGraph but updated more rapidly. 
<br>\n<table>\n<tr>\n<td>\n<code>\nrealtime_interaction_graph.click.count <br>\nrealtime_interaction_graph.click.days_since_last <br>\nrealtime_interaction_graph.fav.count <br>\nrealtime_interaction_graph.fav.days_since_last <br>\nrealtime_interaction_graph.mention.count <br>\nrealtime_interaction_graph.mention.days_since_last <br>\nrealtime_interaction_graph.profile_view.count <br>\nrealtime_interaction_graph.profile_view.days_since_last <br>\nrealtime_interaction_graph.retweet.count <br>\nrealtime_interaction_graph.retweet.days_since_last <br>\nrealtime_interaction_graph.soft_follow.count <br>\nrealtime_interaction_graph.soft_follow.days_since_last\n</code>\n</td>\n</tr>\n</table>\n</details>\n<details>\n<summary><b><code>user_tweet.recommendations</code></b></summary>\n<br>\nSimilarity of a tweet to a user's recent engaged tweets. <br>\n<table>\n<tr>\n<td>\n<code>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.fav_1d_last_10_avg <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.fav_1d_last_10_max <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.fav_7d_last_10_avg <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.fav_7d_last_10_max <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.follow_30d_last_10_avg <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.follow_30d_last_10_max <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.follow_7d_last_10_avg <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.follow_7d_last_10_max <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.retweet_1d_last_10_avg <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.retweet_1d_last_10_max <br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.retweet_7d_last_10_avg 
<br>\nuser_tweet.recommendations.sim_clusters_recent_engagement_similarity.retweet_7d_last_10_max <br>\nuser-tweet.recommendations.sim_clusters_scores.user_interested_in_tweet_embedding_dot_product_20m_145k_2020  <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n<details>\n<summary><b><code>other</code></b></summary>\n<br>\nHere we list individual features not covered in any feature group <br>\n<table>\n<tr>\n<td>\n<code>\nauthor_health.num_connect <br>\nauthor_health.num_connect_days <br>\nauthor_health.num_followers <br>\nengagement_features.in_network.favorites.count <br>\nengagement_features.in_network.replies.count <br>\nengagement_features.in_network.retweets.count <br>\nrequest_context.display_dpi <br>\nrequest_context.display_height <br>\nrequest_context.display_width <br>\nrequest_context.is_get_initial <br>\nrequest_context.is_get_middle <br>\nrequest_context.is_get_newer <br>\nrequest_context.is_get_older <br>\nrequest_context.is_session_start <br>\ntime_features.earlybird.last_favorite_since_creation_hrs <br>\ntime_features.earlybird.last_quote_since_creation_hrs <br>\ntime_features.earlybird.last_reply_since_creation_hrs <br>\ntime_features.earlybird.last_retweet_since_creation_hrs <br>\ntime_features.earlybird.time_since_last_favorite <br>\ntime_features.earlybird.time_since_last_quote <br>\ntime_features.earlybird.time_since_last_reply <br>\ntime_features.earlybird.time_since_last_retweet <br>\ntime_features.is_tweet_recycled <br>\ntime_features.non_polling_requests_since_tweet_creation <br>\ntime_features.time_between_non_polling_requests_avg <br>\ntime_features.time_since_last_non_polling_request <br>\ntime_features.time_since_source_tweet_creation <br>\ntime_features.time_since_tweet_creation <br>\ntime_features.time_since_viewer_account_creation_secs <br>\ntime_features.tweet_age_ratio <br>\n</code>\n</td>\n</tr>\n</table>\n</details>\n\n## Embeddings Features\n\n[Twhin](https://arxiv.org/pdf/2202.05387.pdf) is a large graph embedding trained on 
Twitter data. We use three 200-dimensional embeddings sourced from the Twhin algorithm.\n\n<details>\n<summary><b><code>Twhin Follow Embeddings</code></b></summary>\n<br>\nWe have two embeddings trained on the user-user follow graph, one representing who is likely to follow a user and the other representing who a user is likely to follow. Each embedding is 200-dimensional.\n</details>\n\n<details>\n<summary><b><code>Twhin Engagement Embeddings</code></b></summary>\n<br>\nWe have one embedding trained on the user-tweet engagement graph, representing users based on the Tweets they are likely to engage with. This embedding is 200-dimensional.\n</details>\n"
  },
  {
    "path": "projects/home/recap/README.md",
    "content": "# Heavy Ranker\n\n## Overview\n\nThe heavy ranker is a machine learning model used to rank tweets for the \"For You\" timeline\nwhich have passed through the candidate retrieval stage. It is one of the final stages of the funnel, \nsucceeded primarily by a set of filtering heuristics.\n\nThe model receives features describing a Tweet and the user that the Tweet is being recommended to \n(see [FEATURES.md](./FEATURES.md)). The model architecture is a parallel [MaskNet](https://arxiv.org/abs/2102.07619) \nwhich outputs a set of numbers between 0 and 1, with each output representing the probability that the user \nwill engage with the tweet in a particular way. The predicted engagement types are explained below:\n```\nscored_tweets_model_weight_fav: The probability the user will favorite the Tweet.\nscored_tweets_model_weight_retweet: The probability the user will Retweet the Tweet.\nscored_tweets_model_weight_reply: The probability the user replies to the Tweet.\nscored_tweets_model_weight_good_profile_click: The probability the user opens the Tweet author profile and Likes or replies to a Tweet.\nscored_tweets_model_weight_video_playback50: The probability (for a video Tweet) that the user will watch at least half of the video.\nscored_tweets_model_weight_reply_engaged_by_author: The probability the user replies to the Tweet and this reply is engaged by the Tweet author.\nscored_tweets_model_weight_good_click: The probability the user will click into the conversation of this Tweet and reply or Like a Tweet.\nscored_tweets_model_weight_good_click_v2: The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes.\nscored_tweets_model_weight_negative_feedback_v2: The probability the user will react negatively (requesting \"show less often\" on the Tweet or author, block or mute the Tweet author).\nscored_tweets_model_weight_report: The probability the user will click Report Tweet.\n```\n\nThe outputs of 
the model are combined into a final model score by doing a weighted sum across the predicted engagement probabilities. \nThe weight of each engagement probability comes from a configuration file, read by the serving stack \n[here](https://github.com/twitter/the-algorithm/blob/main/home-mixer/server/src/main/scala/com/twitter/home_mixer/product/scored_tweets/param/ScoredTweetsParam.scala#L84). The exact weights in the file can be adjusted at any time, but the current weighting of probabilities \n(April 5, 2023) is as follows:\n```\nscored_tweets_model_weight_fav: 0.5\nscored_tweets_model_weight_retweet: 1.0\nscored_tweets_model_weight_reply: 13.5\nscored_tweets_model_weight_good_profile_click: 12.0\nscored_tweets_model_weight_video_playback50: 0.005\nscored_tweets_model_weight_reply_engaged_by_author: 75.0\nscored_tweets_model_weight_good_click: 11.0\nscored_tweets_model_weight_good_click_v2: 10.0\nscored_tweets_model_weight_negative_feedback_v2: -74.0\nscored_tweets_model_weight_report: -369.0\n```\n\nEssentially, the formula is:\n```\nscore = sum_i { (weight of engagement i) * (probability of engagement i) }\n```\n\nSince each engagement has a different average probability, the weights were originally set so that, \non average, each weighted engagement probability contributes a near-equal amount to the score. \nSince then, we have periodically adjusted the weights to optimize for platform metrics.\n\nSome disclaimers:\n- Due to the need to make sure this runs independently from other parts of Twitter codebase, there may be small differences from the production model.\n- We cannot release the real training data due to privacy restrictions. 
However, we have included a script to generate random data to ensure you can run the model training code.\n\n## Development\nAfter following the repo setup instructions, you can run the following script from a virtual environment to create a \nrandom training dataset in `$HOME/tmp/recap_local_random_data`:\n```sh\nprojects/home/recap/scripts/create_random_data.sh\n```\n\nYou can then train the model using the following script.\nCheckpoints and logs will be written to `$HOME/tmp/runs/recap_local_debug`:\n```sh\nprojects/home/recap/scripts/run_local.sh\n```\n\nThe model training can be configured in `projects/home/recap/config/local_prod.yaml`\n"
  },
  {
    "path": "projects/home/recap/__init__.py",
    "content": ""
  },
  {
    "path": "projects/home/recap/config/home_recap_2022/segdense.json",
    "content": "{\n  \"schema\": [\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"home_recap_2022_discrete__segdense_vals\",\n      \"length\": 320\n    },\n    {\n      \"dtype\": \"float_list\",\n      \"feature_name\": \"home_recap_2022_cont__segdense_vals\",\n      \"length\": 6000\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"home_recap_2022_binary__segdense_vals\",\n      \"length\": 512\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_tweet_detail_dwelled_15_sec\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_profile_clicked_and_profile_engaged\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_replied_reply_engaged_by_author\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_video_playback_50\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_report_tweet_clicked\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_replied\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"meta.author_id\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_negative_feedback_v2\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_retweeted\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_favorited\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied\",\n      \"length\": 1\n    
},\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"meta.tweet_id\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_good_clicked_convo_desc_v2\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"meta.user_id\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_bookmarked\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"recap.engagement.is_shared\",\n      \"length\": 1\n    },\n    {\n      \"dtype\": \"float_list\",\n      \"feature_name\": \"user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings\",\n      \"length\": 200\n    },\n    {\n      \"dtype\": \"float_list\",\n      \"feature_name\": \"original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings\",\n      \"length\": 200\n    },\n    {\n      \"dtype\": \"float_list\",\n      \"feature_name\": \"user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings\",\n      \"length\": 200\n    }\n  ]\n}"
  },
  {
    "path": "projects/home/recap/config/local_prod.yaml",
    "content": "training:\n  num_train_steps: 10\n  num_eval_steps: 5\n  checkpoint_every_n: 5\n  train_log_every_n: 1\n  eval_log_every_n: 1\n  save_dir: ${HOME}/tmp/runs/recap_local_debug\n  eval_timeout_in_s: 7200\nmodel:\n  backbone:\n    affine_map: null\n    dcn_config: null\n    dlrm_config: null\n    mask_net_config:\n      mask_blocks:\n        - aggregation_size: 1024\n          input_layer_norm: false\n          output_size: 1024\n          reduction_factor: null\n        - aggregation_size: 1024\n          input_layer_norm: false\n          output_size: 1024\n          reduction_factor: null\n        - aggregation_size: 1024\n          input_layer_norm: false\n          output_size: 1024\n          reduction_factor: null\n        - aggregation_size: 1024\n          input_layer_norm: false\n          output_size: 1024\n          reduction_factor: null\n      mlp:\n        batch_norm: null\n        dropout: null\n        final_layer_activation: true\n        layer_sizes:\n          - 2048\n      use_parallel: true\n    mlp_config: null\n    pos_weight: 1.0\n  featurization_config:\n    clip_log1p_abs_config: null\n    double_norm_log_config:\n      batch_norm_config:\n        affine: true\n        momentum: 0.01\n      clip_magnitude: 5.0\n      layer_norm_config:\n        axis: -1\n        center: true\n        epsilon: 0.0\n        scale: true\n    feature_names_to_concat:\n      - binary\n    log1p_abs_config: null\n    z_score_log_config: null\n  large_embeddings: null\n  multi_task_type: share_all\n  position_debias_config: null\n  small_embeddings: null\n  stratifiers: null\n  tasks:\n    recap.engagement.is_favorited:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      
pos_weight: 1.0\n    recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_good_clicked_convo_desc_v2:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_negative_feedback_v2:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout:\n          rate: 0.1\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_profile_clicked_and_profile_engaged:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_replied:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 
128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_replied_reply_engaged_by_author:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_report_tweet_clicked:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout:\n          rate: 0.2\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_retweeted:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\n    recap.engagement.is_video_playback_50:\n      affine_map: null\n      dcn_config: null\n      dlrm_config: null\n      mask_net_config: null\n      mlp_config:\n        batch_norm:\n          affine: false\n          momentum: 0.1\n        dropout: null\n        final_layer_activation: false\n        layer_sizes:\n          - 256\n          - 128\n          - 1\n      pos_weight: 1.0\ntrain_data:\n  global_batch_size: 128\n  dataset_service_compression: AUTO\n  inputs: &data_root \"${HOME}/tmp/recap_local_random_data/*.gz\"\n  seg_dense_schema: &seg_dense_schema\n    schema_path: \"${TML_BASE}/projects/home/recap/config/home_recap_2022/segdense.json\"\n    renamed_features:\n      \"continuous\": \"home_recap_2022_cont__segdense_vals\"\n     
 \"binary\": \"home_recap_2022_binary__segdense_vals\"\n      \"discrete\": \"home_recap_2022_discrete__segdense_vals\"\n      \"author_embedding\": \"original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings\"\n      \"user_embedding\": \"user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings\"\n      \"user_eng_embedding\": \"user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings\"\n      \"meta__author_id\": \"meta.author_id\"\n      \"meta__user_id\": \"meta.user_id\"\n      \"meta__tweet_id\": \"meta.tweet_id\"\n  tasks: &data_tasks\n    \"recap.engagement.is_bookmarked\": {}\n    \"recap.engagement.is_favorited\": {}\n    \"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied\": {}\n    \"recap.engagement.is_good_clicked_convo_desc_v2\": {}\n    \"recap.engagement.is_negative_feedback_v2\": {}\n    \"recap.engagement.is_profile_clicked_and_profile_engaged\": {}\n    \"recap.engagement.is_replied\": {}\n    \"recap.engagement.is_replied_reply_engaged_by_author\": {}\n    \"recap.engagement.is_report_tweet_clicked\": {}\n    \"recap.engagement.is_retweeted\": {}\n    \"recap.engagement.is_shared\": {}\n    \"recap.engagement.is_tweet_detail_dwelled_15_sec\": {}\n    \"recap.engagement.is_video_playback_50\": {}\n  preprocess: &preprocess\n    truncate_and_slice:\n      continuous_feature_truncation: 2117\n      binary_feature_truncation: 59\nvalidation_data:\n  validation: &validation\n    global_batch_size: &eval_batch_size 128\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks: *data_tasks\n    preprocess: *preprocess\n  train:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks: *data_tasks\n    preprocess: *preprocess\n  recap.engagement.is_favorited:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: 
*data_tasks\n      \"recap.engagement.is_favorited\":\n        pos_downsampling_rate: 0.8387\n        neg_downsampling_rate: 0.01\n    evaluation_tasks:\n      - \"recap.engagement.is_favorited\"\n    preprocess: *preprocess\n  recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied\":\n        pos_downsampling_rate: 0.9164\n        neg_downsampling_rate: 0.00195\n    evaluation_tasks:\n      - \"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied\"\n    preprocess: *preprocess\n  recap.engagement.is_good_clicked_convo_desc_v2:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_good_clicked_convo_desc_v2\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.00174\n    evaluation_tasks:\n      - \"recap.engagement.is_good_clicked_convo_desc_v2\"\n    preprocess: *preprocess\n  recap.engagement.is_negative_feedback_v2:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_negative_feedback_v2\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.00280\n    evaluation_tasks:\n      - \"recap.engagement.is_negative_feedback_v2\"\n    preprocess: *preprocess\n  recap.engagement.is_profile_clicked_and_profile_engaged:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_profile_clicked_and_profile_engaged\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.0015\n    evaluation_tasks:\n      - 
\"recap.engagement.is_profile_clicked_and_profile_engaged\"\n    preprocess: *preprocess\n  recap.engagement.is_replied:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_replied\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.005\n    evaluation_tasks:\n      - \"recap.engagement.is_replied\"\n    preprocess: *preprocess\n  recap.engagement.is_replied_reply_engaged_by_author:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_replied_reply_engaged_by_author\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.001\n    evaluation_tasks:\n      - \"recap.engagement.is_replied_reply_engaged_by_author\"\n    preprocess: *preprocess\n  recap.engagement.is_report_tweet_clicked:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_report_tweet_clicked\":\n        pos_downsampling_rate: 1.0\n        neg_downsampling_rate: 0.000014\n    evaluation_tasks:\n      - \"recap.engagement.is_report_tweet_clicked\"\n    preprocess: *preprocess\n  recap.engagement.is_retweeted:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_retweeted\":\n        pos_downsampling_rate: 0.9561\n        neg_downsampling_rate: 0.004\n    evaluation_tasks:\n      - \"recap.engagement.is_retweeted\"\n    preprocess: *preprocess\n  recap.engagement.is_video_playback_50:\n    global_batch_size: *eval_batch_size\n    inputs: *data_root\n    seg_dense_schema: *seg_dense_schema\n    tasks:\n      <<: *data_tasks\n      \"recap.engagement.is_video_playback_50\":\n        pos_downsampling_rate: 
1.0\n        neg_downsampling_rate: 0.00427\n    evaluation_tasks:\n      - \"recap.engagement.is_video_playback_50\"\n    preprocess: *preprocess\n\noptimizer:\n  adam:\n    beta_1: 0.95\n    beta_2: 0.999\n    epsilon: 1.0e-07\n  multi_task_learning_rates:\n    backbone_learning_rate:\n      constant: null\n      linear_ramp_to_constant:\n        learning_rate: 0.0001\n        num_ramp_steps: 1000\n      linear_ramp_to_cosine: null\n      piecewise_constant: null\n    tower_learning_rates:\n      recap.engagement.is_favorited:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0008\n          num_ramp_steps: 5000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0001\n          num_ramp_steps: 2000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_good_clicked_convo_desc_v2:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0002\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_negative_feedback_v2:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0005\n          num_ramp_steps: 5000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_profile_clicked_and_profile_engaged:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0003\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_replied:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.001\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      
recap.engagement.is_replied_reply_engaged_by_author:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0001\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_report_tweet_clicked:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0001\n          num_ramp_steps: 3000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_retweeted:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.0001\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n      recap.engagement.is_video_playback_50:\n        constant: null\n        linear_ramp_to_constant:\n          learning_rate: 0.003\n          num_ramp_steps: 1000\n        linear_ramp_to_cosine: null\n        piecewise_constant: null\n  single_task_learning_rate: null\n"
  },
  {
    "path": "projects/home/recap/config.py",
    "content": "from tml.core import config as config_mod\nimport tml.projects.home.recap.data.config as data_config\nimport tml.projects.home.recap.model.config as model_config\nimport tml.projects.home.recap.optimizer.config as optimizer_config\n\nfrom enum import Enum\nfrom typing import Dict, Optional\nimport pydantic\n\n\nclass TrainingConfig(config_mod.BaseConfig):\n  save_dir: str = \"/tmp/model\"\n  num_train_steps: pydantic.PositiveInt = 1000000\n  initial_checkpoint_dir: str = pydantic.Field(\n    None, description=\"Directory of initial checkpoints\", at_most_one_of=\"initialization\"\n  )\n  checkpoint_every_n: pydantic.PositiveInt = 1000\n  checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Maximum number of checkpoints to keep. Defaults to keeping all.\"\n  )\n  train_log_every_n: pydantic.PositiveInt = 1000\n  num_eval_steps: int = pydantic.Field(\n    16384, description=\"Number of evaluation steps. If < 0 the entire dataset \" \"will be used.\"\n  )\n  eval_log_every_n: pydantic.PositiveInt = 5000\n\n  eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60\n\n  gradient_accumulation: int = pydantic.Field(\n    None, description=\"Number of replica steps to accumulate gradients.\"\n  )\n\n\nclass RecapConfig(config_mod.BaseConfig):\n  training: TrainingConfig = pydantic.Field(TrainingConfig())\n  model: model_config.ModelConfig\n  train_data: data_config.RecapDataConfig\n  validation_data: Dict[str, data_config.RecapDataConfig]\n  optimizer: optimizer_config.RecapOptimizerConfig\n\n  which_metrics: Optional[str] = pydantic.Field(None, description=\"which metrics to pick.\")\n\n  # DANGER DANGER! You might expect validators here to ensure that multi task learning setups are\n  # the same as the data. Unfortunately, this throws opaque errors when the model configuration is\n  # invalid. 
In our judgement, that is a more frequent and worse occurrence than tasks not matching\n  # the data.\n\n\nclass JobMode(str, Enum):\n  \"\"\"Job modes.\"\"\"\n\n  TRAIN = \"train\"\n  EVALUATE = \"evaluate\"\n  INFERENCE = \"inference\"\n"
  },
  {
    "path": "projects/home/recap/data/__init__.py",
    "content": ""
  },
  {
    "path": "projects/home/recap/data/config.py",
    "content": "import typing\nfrom enum import Enum\n\n\nfrom tml.core import config as base_config\n\nimport pydantic\n\n\nclass ExplicitDateInputs(base_config.BaseConfig):\n  \"\"\"Arguments to select train/validation data using end_date and days of data.\"\"\"\n\n  data_root: str = pydantic.Field(..., description=\"Data path prefix.\")\n  end_date: str = pydantic.Field(..., description=\"Data end date, inclusive.\")\n  days: int = pydantic.Field(..., description=\"Number of days of data for dataset.\")\n  num_missing_days_tol: int = pydantic.Field(\n    0, description=\"We tolerate <= num_missing_days_tol days of missing data.\"\n  )\n\n\nclass ExplicitDatetimeInputs(base_config.BaseConfig):\n  \"\"\"Arguments to select train/validation data using end_datetime and hours of data.\"\"\"\n\n  data_root: str = pydantic.Field(..., description=\"Data path prefix.\")\n  end_datetime: str = pydantic.Field(..., description=\"Data end datetime, inclusive.\")\n  hours: int = pydantic.Field(..., description=\"Number of hours of data for dataset.\")\n  num_missing_hours_tol: int = pydantic.Field(\n    0, description=\"We tolerate <= num_missing_hours_tol hours of missing data.\"\n  )\n\n\nclass DdsCompressionOption(str, Enum):\n  \"\"\"The only valid compression option is 'AUTO'\"\"\"\n\n  AUTO = \"AUTO\"\n\n\nclass DatasetConfig(base_config.BaseConfig):\n  inputs: str = pydantic.Field(\n    None, description=\"A glob for selecting data.\", one_of=\"date_inputs_format\"\n  )\n  explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(\n    None, one_of=\"date_inputs_format\"\n  )\n  explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of=\"date_inputs_format\")\n\n  global_batch_size: pydantic.PositiveInt\n\n  num_files_to_keep: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Number of shards to keep.\"\n  )\n  repeat_files: bool = pydantic.Field(\n    True, description=\"DEPRICATED. 
Files are repeated no matter what this is set to.\"\n  )\n  file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description=\"File batch size\")\n\n  cache: bool = pydantic.Field(\n    False,\n    description=\"Cache dataset in memory. Careful to only use this when you\"\n    \" have enough memory to fit entire dataset.\",\n  )\n\n  data_service_dispatcher: str = pydantic.Field(None)\n  ignore_data_errors: bool = pydantic.Field(\n    False, description=\"Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs.\"\n  )\n  dataset_service_compression: DdsCompressionOption = pydantic.Field(\n    None,\n    description=\"Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'\",\n  )\n\n  # tf.data.Dataset options\n  examples_shuffle_buffer_size: int = pydantic.Field(1024, description=\"Size of shuffle buffers.\")\n  map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Number of parallel calls.\"\n  )\n  interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Number of shards to interleave.\"\n  )\n\n\nclass TruncateAndSlice(base_config.BaseConfig):\n  # Apply truncation and then slice.\n  continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Experimental. Truncates continuous features to this amount for efficiency.\"\n  )\n  binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Experimental. 
Truncates binary features to this amount for efficiency.\"\n  )\n\n  continuous_feature_mask_path: str = pydantic.Field(\n    None, description=\"Path of mask used to slice input continuous features.\"\n  )\n  binary_feature_mask_path: str = pydantic.Field(\n    None, description=\"Path of mask used to slice input binary features.\"\n  )\n\n\nclass DataType(str, Enum):\n  BFLOAT16 = \"bfloat16\"\n  BOOL = \"bool\"\n\n  FLOAT32 = \"float32\"\n  FLOAT16 = \"float16\"\n\n  UINT8 = \"uint8\"\n\n\nclass DownCast(base_config.BaseConfig):\n  # Apply down casting to selected features.\n  features: typing.Dict[str, DataType] = pydantic.Field(\n    None, description=\"Map features to down cast data types.\"\n  )\n\n\nclass TaskData(base_config.BaseConfig):\n  pos_downsampling_rate: float = pydantic.Field(\n    1.0,\n    description=\"Downsampling rate of positives used to generate dataset.\",\n  )\n  neg_downsampling_rate: float = pydantic.Field(\n    1.0,\n    description=\"Downsampling rate of negatives used to generate dataset.\",\n  )\n\n\nclass SegDenseSchema(base_config.BaseConfig):\n  schema_path: str = pydantic.Field(..., description=\"Path to feature config json.\")\n  features: typing.List[str] = pydantic.Field(\n    [],\n    description=\"List of features (in addition to the renamed features) to read from schema path above.\",\n  )\n  renamed_features: typing.Dict[str, str] = pydantic.Field(\n    {}, description=\"Dictionary of renamed features.\"\n  )\n  mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(\n    {},\n    description=\"(experimental) Number of mantissa bits to mask to simulate lower precision data.\",\n  )\n\n\nclass RectifyLabels(base_config.BaseConfig):\n  label_rectification_window_in_hours: float = pydantic.Field(\n    3.0, description=\"overlap time in hours for which to flip labels\"\n  )\n  served_timestamp_field: str = pydantic.Field(\n    ..., description=\"input field corresponding to served time\"\n  )\n  
impressed_timestamp_field: str = pydantic.Field(\n    ..., description=\"input field corresponding to impressed time\"\n  )\n  label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(\n    ..., description=\"label to the input field corresponding to engagement time\"\n  )\n\n\nclass ExtractFeaturesRow(base_config.BaseConfig):\n  name: str = pydantic.Field(\n    ...,\n    description=\"name of the new field name to be created\",\n  )\n  source_tensor: str = pydantic.Field(\n    ...,\n    description=\"name of the dense tensor to look for the feature\",\n  )\n  index: int = pydantic.Field(\n    ...,\n    description=\"index of the feature in the dense tensor\",\n  )\n\n\nclass ExtractFeatures(base_config.BaseConfig):\n  extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(\n    [],\n    description=\"list of features to be extracted with their name, source tensor and index\",\n  )\n\n\nclass DownsampleNegatives(base_config.BaseConfig):\n  batch_multiplier: int = pydantic.Field(\n    None,\n    description=\"batch multiplier\",\n  )\n  engagements_list: typing.List[str] = pydantic.Field(\n    [],\n    description=\"engagements with kept positives\",\n  )\n  num_engagements: int = pydantic.Field(\n    ...,\n    description=\"number engagements used in the model, including ones excluded in engagements_list\",\n  )\n\n\nclass Preprocess(base_config.BaseConfig):\n  truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description=\"Truncation and slicing.\")\n  downcast: DownCast = pydantic.Field(None, description=\"Down cast to features.\")\n  rectify_labels: RectifyLabels = pydantic.Field(\n    None, description=\"Rectify labels for a given overlap window\"\n  )\n  extract_features: ExtractFeatures = pydantic.Field(\n    None, description=\"Extract features from dense tensors.\"\n  )\n  downsample_negatives: DownsampleNegatives = pydantic.Field(\n    None, description=\"Downsample negatives.\"\n  )\n\n\nclass 
Sampler(base_config.BaseConfig):\n  \"\"\"Assumes function is defined in data/samplers.py.\n\n  Only use this for quick experimentation.\n  If samplers are useful, we should sample from upstream data generation.\n\n  DEPRICATED, DO NOT USE.\n  \"\"\"\n\n  name: str\n  kwargs: typing.Dict\n\n\nclass RecapDataConfig(DatasetConfig):\n  seg_dense_schema: SegDenseSchema\n\n  tasks: typing.Dict[str, TaskData] = pydantic.Field(\n    description=\"Description of individual tasks in this dataset.\"\n  )\n  evaluation_tasks: typing.List[str] = pydantic.Field(\n    [], description=\"If specified, lists the tasks we're generating metrics for.\"\n  )\n\n  preprocess: Preprocess = pydantic.Field(\n    None, description=\"Function run in tf.data.Dataset at train/eval, in-graph at inference.\"\n  )\n\n  sampler: Sampler = pydantic.Field(\n    None,\n    description=\"\"\"DEPRICATED, DO NOT USE. Sampling function for offline experiments.\"\"\",\n  )\n\n  @pydantic.root_validator()\n  def _validate_evaluation_tasks(cls, values):\n    if values.get(\"evaluation_tasks\") is not None:\n      for task in values[\"evaluation_tasks\"]:\n        if task not in values[\"tasks\"]:\n          raise KeyError(f\"Evaluation task {task} must be in tasks. Received {values['tasks']}\")\n    return values\n"
  },
  {
    "path": "projects/home/recap/data/dataset.py",
    "content": "from dataclasses import dataclass\nfrom typing import Callable, List, Optional, Tuple, Dict\nimport functools\n\nimport torch\nimport tensorflow as tf\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.projects.home.recap.data.config import RecapDataConfig, TaskData\nfrom tml.projects.home.recap.data import preprocessors\nfrom tml.projects.home.recap.config import JobMode\nfrom tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn\nfrom tml.projects.home.recap.data.util import (\n  keyed_jagged_tensor_from_tensors_dict,\n  sparse_or_dense_tf_to_torch,\n)\nfrom absl import logging\nimport torch.distributed as dist\n\n\n@dataclass\nclass RecapBatch(DataclassBatch):\n  \"\"\"Holds features and labels from the Recap dataset.\"\"\"\n\n  continuous_features: torch.Tensor\n  binary_features: torch.Tensor\n  discrete_features: torch.Tensor\n  sparse_features: \"KeyedJaggedTensor\"  # type: ignore[name-defined]  # noqa: F821\n  labels: torch.Tensor\n  user_embedding: torch.Tensor = None\n  user_eng_embedding: torch.Tensor = None\n  author_embedding: torch.Tensor = None\n  weights: torch.Tensor = None\n\n  def __post_init__(self):\n    if self.weights is None:\n      self.weights = torch.ones_like(self.labels)\n    for feature_name, feature_value in self.as_dict().items():\n      if (\"embedding\" in feature_name) and (feature_value is None):\n        setattr(self, feature_name, torch.empty([0, 0]))\n\n\ndef to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:\n  \"\"\"Converts a torch data loader output into `RecapBatch`.\"\"\"\n\n  x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)\n  try:\n    features_in, labels = x\n  except ValueError:\n    # For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple\n    features_in, labels = x, None\n\n  sparse_features = keyed_jagged_tensor_from_tensors_dict({})\n  if sparse_feature_names:\n    
sparse_features = keyed_jagged_tensor_from_tensors_dict(\n      {embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}\n    )\n\n  user_embedding, user_eng_embedding, author_embedding = None, None, None\n  if \"user_embedding\" in features_in:\n    if sparse_feature_names and \"meta__user_id\" in sparse_feature_names:\n      raise ValueError(\"Only one source of embedding for user is supported\")\n    else:\n      user_embedding = features_in[\"user_embedding\"]\n\n  if \"user_eng_embedding\" in features_in:\n    if sparse_feature_names and \"meta__user_eng_id\" in sparse_feature_names:\n      raise ValueError(\"Only one source of embedding for user is supported\")\n    else:\n      user_eng_embedding = features_in[\"user_eng_embedding\"]\n\n  if \"author_embedding\" in features_in:\n    if sparse_feature_names and \"meta__author_id\" in sparse_feature_names:\n      raise ValueError(\"Only one source of embedding for user is supported\")\n    else:\n      author_embedding = features_in[\"author_embedding\"]\n\n  return RecapBatch(\n    continuous_features=features_in[\"continuous\"],\n    binary_features=features_in[\"binary\"],\n    discrete_features=features_in[\"discrete\"],\n    sparse_features=sparse_features,\n    user_embedding=user_embedding,\n    user_eng_embedding=user_eng_embedding,\n    author_embedding=author_embedding,\n    labels=labels,\n    weights=features_in.get(\"weights\", None),  # Defaults to torch.ones_like(labels)\n  )\n\n\ndef _chain(param, f1, f2):\n  \"\"\"\n  Reduce multiple functions into one chained function\n  _chain(x, f1, f2) -> f2(f1(x))\n  \"\"\"\n  output = param\n  fns = [f1, f2]\n  for f in fns:\n    output = f(output)\n  return output\n\n\ndef _add_weights(inputs, tasks: Dict[str, TaskData]):\n  \"\"\"Adds weights based on label sampling for positive and negatives.\n\n  This is useful for numeric calibration etc. 
This mutates inputs.\n\n  Args:\n    inputs: A dictionary of strings to tensor-like structures.\n    tasks: A dict of string (label) to `TaskData` specifying inputs.\n\n  Returns:\n    A tuple of features and labels; weights are added to features.\n  \"\"\"\n\n  weights = []\n  for key, task in tasks.items():\n    label = inputs[key]\n    float_label = tf.cast(label, tf.float32)\n\n    weights.append(\n      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate\n    )\n\n  # Ensure we are batch-major (assumes we batch before this call).\n  inputs[\"weights\"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)\n  return inputs\n\n\ndef get_datetimes(explicit_datetime_inputs):\n  \"\"\"Compute list datetime strings for train/validation data.\"\"\"\n  datetime_format = \"%Y/%m/%d/%H\"\n  end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)\n  dates = sorted(\n    [\n      (end - timedelta(hours=i + 1)).strftime(datetime_format)\n      for i in range(int(explicit_datetime_inputs.hours))\n    ]\n  )\n  return dates\n\n\ndef get_explicit_datetime_inputs_files(explicit_datetime_inputs):\n  \"\"\"\n  Compile list of files for training/validation.\n\n  Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.\n  For each hour of data, if the directory is missing or empty, we increment a counter to keep\n  track of the number of missing data hours.\n  Returns only files with a `.gz` extension.\n\n  Args:\n    explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object\n\n  Returns:\n    data_files: Sorted list of files to read corresponding to data at the desired datetimes\n    num_hours_missing: Number of hours that we are missing data\n\n  \"\"\"\n  datetimes = get_datetimes(explicit_datetime_inputs)\n  folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]\n  data_files = []\n  
num_hours_missing = 0\n  for folder in folders:\n    try:\n      files = tf.io.gfile.listdir(folder)\n      if not files:\n        logging.warning(f\"{folder} contained no data files\")\n        num_hours_missing += 1\n      data_files.extend(\n        [\n          os.path.join(folder, filename)\n          for filename in files\n          if filename.rsplit(\".\", 1)[-1].lower() == \"gz\"\n        ]\n      )\n    except tf.errors.NotFoundError as e:\n      num_hours_missing += 1\n      logging.warning(f\"Cannot find directory {folder}. Missing one hour of data. Error: \\n {e}\")\n  return sorted(data_files), num_hours_missing\n\n\ndef _map_output_for_inference(\n  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False\n):\n  if preprocessor:\n    raise ValueError(\"No preprocessor should be used at inference time.\")\n  if add_weights:\n    raise NotImplementedError()\n\n  # Add zero weights.\n  inputs[\"weights\"] = tf.zeros_like(tf.expand_dims(inputs[\"continuous\"][:, 0], -1))\n  for label in tasks:\n    del inputs[label]\n  return inputs\n\n\ndef _map_output_for_train_eval(\n  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False\n):\n  if add_weights:\n    inputs = _add_weights_based_on_sampling_rates(inputs, tasks)\n\n  # Warning this has to happen first as it changes the input\n  if preprocessor:\n    inputs = preprocessor(inputs)\n\n  label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])\n\n  for label in tasks:\n    del inputs[label]\n\n  return inputs, label_values\n\n\ndef _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):\n  \"\"\"Adds weights based on label sampling for positive and negatives.\n\n  This is useful for numeric calibration etc. 
This mutates inputs.\n\n  Args:\n    inputs: A dictionary of strings to tensor-like structures.\n    tasks: A dict of string (label) to `TaskData` specifying inputs.\n\n  Returns:\n    A tuple of features and labels; weights are added to features.\n  \"\"\"\n  weights = []\n  for key, task in tasks.items():\n    label = inputs[key]\n    float_label = tf.cast(label, tf.float32)\n\n    weights.append(\n      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate\n    )\n\n  # Ensure we are batch-major (assumes we batch before this call).\n  inputs[\"weights\"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)\n  return inputs\n\n\nclass RecapDataset(torch.utils.data.IterableDataset):\n  def __init__(\n    self,\n    data_config: RecapDataConfig,\n    dataset_service: Optional[str] = None,\n    mode: JobMode = JobMode.TRAIN,\n    compression: Optional[str] = \"AUTO\",\n    repeat: bool = False,\n    vocab_mapper: tf.keras.Model = None,\n  ):\n    logging.info(\"***** Labels *****\")\n    logging.info(list(data_config.tasks.keys()))\n\n    self._data_config = data_config\n    self._parse_fn = get_seg_dense_parse_fn(data_config)\n    self._mode = mode\n    self._repeat = repeat\n    self._num_concurrent_iterators = 1\n    self._vocab_mapper = vocab_mapper\n    self.dataset_service = dataset_service\n\n    preprocessor = None\n    self._batch_size_multiplier = 1\n    if data_config.preprocess:\n      preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)\n      if data_config.preprocess.downsample_negatives:\n        self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier\n\n    self._preprocessor = preprocessor\n\n    if mode == JobMode.INFERENCE:\n      if preprocessor is not None:\n        raise ValueError(\"Expect no preprocessor at inference time.\")\n      should_add_weights = False\n      output_map_fn = _map_output_for_inference  # (features,)\n    
else:\n      # Only add weights if there is a reason to! If all weights will\n      # be equal to 1.0, save bandwidth between DDS and Chief by simply\n      # relying on the fact that weights default to 1.0 in `RecapBatch`\n      # WARNING: Weights may still be added as a side effect of a preprocessor\n      #          such as `DownsampleNegatives`.\n      should_add_weights = any(\n        [\n          task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0\n          for task_cfg in data_config.tasks.values()\n        ]\n      )\n      output_map_fn = _map_output_for_train_eval  # (features, labels)\n\n    self._output_map_fn = functools.partial(\n      output_map_fn,\n      tasks=data_config.tasks,\n      preprocessor=preprocessor,\n      add_weights=should_add_weights,\n    )\n\n    sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None\n\n    self._tf_dataset = self._create_tf_dataset()\n\n    self._init_tensor_spec()\n\n  def _init_tensor_spec(self):\n    def _tensor_spec_to_torch_shape(spec):\n      if spec.shape is None:\n        return None\n      shape = [x if x is not None else -1 for x in spec.shape]\n      return torch.Size(shape)\n\n    self.torch_element_spec = tf.nest.map_structure(\n      _tensor_spec_to_torch_shape, self._tf_dataset.element_spec\n    )\n\n  def _create_tf_dataset(self):\n    if hasattr(self, \"_tf_dataset\"):\n      raise ValueError(\"Do not call `_create_tf_dataset` more than once.\")\n\n    world_size = dist.get_world_size() if dist.is_initialized() else 1\n    per_replica_bsz = (\n      self._batch_size_multiplier * self._data_config.global_batch_size // world_size\n    )\n\n    dataset: tf.data.Dataset = self._create_base_tf_dataset(\n      batch_size=per_replica_bsz,\n    )\n\n    if self._repeat:\n      logging.info(\"Repeating dataset\")\n      dataset = dataset.repeat()\n\n    if self.dataset_service:\n      if self._num_concurrent_iterators > 1:\n        if not 
self.machines_config:\n          raise ValueError(\n            \"Must supply a machine_config for autotuning in order to use >1 concurrent iterators\"\n          )\n        dataset = dataset_lib.with_auto_tune_budget(\n          dataset,\n          machine_config=self.machines_config.chief,\n          num_concurrent_iterators=self.num_concurrent_iterators,\n          on_chief=False,\n        )\n\n      self.dataset_id, self.job_name = register_dataset(\n        dataset=dataset, dataset_service=self.dataset_service, compression=self.compression\n      )\n      dataset = distribute_from_dataset_id(\n        dataset_id=self.dataset_id,  # type: ignore[arg-type]\n        job_name=self.job_name,\n        dataset_service=self.dataset_service,\n        compression=self.compression,\n      )\n\n    elif self._num_concurrent_iterators > 1:\n      if not self.machines_config:\n        raise ValueError(\n          \"Must supply a machine_config for autotuning in order to use >1 concurrent iterators\"\n        )\n      dataset = dataset_lib.with_auto_tune_budget(\n        dataset,\n        machine_config=self.machines_config.chief,\n        num_concurrent_iterators=self._num_concurrent_iterators,\n        on_chief=True,\n      )\n\n    # Vocabulary mapping happens on the training node, not in dds because of size.\n    if self._vocab_mapper:\n      dataset = dataset.map(self._vocab_mapper)\n\n    return dataset.prefetch(world_size * 2)\n\n  def _create_base_tf_dataset(self, batch_size: int):\n    if self._data_config.inputs:\n      glob = self._data_config.inputs\n      filenames = sorted(tf.io.gfile.glob(glob))\n    elif self._data_config.explicit_datetime_inputs:\n      num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol\n      filenames, num_hours_missing = get_explicit_datetime_inputs_files(\n        self._data_config.explicit_datetime_inputs,\n        increment=\"hourly\",\n      )\n      if num_hours_missing > num_missing_hours_tol:\n 
       raise ValueError(\n          f\"We are missing {num_hours_missing} hours of data\"\n          f\"more than tolerance {num_missing_hours_tol}.\"\n        )\n    elif self._data_config.explicit_date_inputs:\n      num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol\n      filenames, num_days_missing = get_explicit_datetime_inputs_files(\n        self._data_config.explicit_date_inputs,\n        increment=\"daily\",\n      )\n      if num_days_missing > num_missing_days_tol:\n        raise ValueError(\n          f\"We are missing {num_days_missing} days of data\"\n          f\"more than tolerance {num_missing_days_tol}.\"\n        )\n    else:\n      raise ValueError(\n        \"Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config\"\n      )\n\n    num_files = len(filenames)\n    logging.info(f\"Found {num_files} data files\")\n    if num_files < 1:\n      raise ValueError(\"No data files found\")\n\n    if self._data_config.num_files_to_keep is not None:\n      filenames = filenames[: self._data_config.num_files_to_keep]\n      logging.info(f\"Retaining only {len(filenames)} files.\")\n\n    filenames_ds = (\n      tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))\n      # Because of drop_remainder, if our dataset does not fill\n      # up a batch, it will emit nothing without this repeat.\n      .repeat(-1)\n    )\n\n    if self._data_config.file_batch_size:\n      filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)\n\n    def per_shard_dataset(filename):\n      ds = tf.data.TFRecordDataset([filename], compression_type=\"GZIP\")\n      return ds.prefetch(4)\n\n    ds = filenames_ds.interleave(\n      per_shard_dataset,\n      block_length=4,\n      deterministic=False,\n      num_parallel_calls=self._data_config.interleave_num_parallel_calls\n      or tf.data.experimental.AUTOTUNE,\n    )\n\n    # Combine functions into one map call to reduce 
overhead.\n    map_fn = functools.partial(\n      _chain,\n      f1=self._parse_fn,\n      f2=self._output_map_fn,\n    )\n\n    # Shuffle -> Batch -> Parse is the correct ordering\n    # Shuffling needs to be performed before batching otherwise there is not much point\n    # Batching happens before parsing because tf.Example parsing is actually vectorized\n    #     and works much faster overall on batches of data.\n    ds = (\n      # DANGER DANGER: there is a default shuffle size here.\n      ds.shuffle(self._data_config.examples_shuffle_buffer_size)\n      .batch(batch_size=batch_size, drop_remainder=True)\n      .map(\n        map_fn,\n        num_parallel_calls=self._data_config.map_num_parallel_calls\n        or tf.data.experimental.AUTOTUNE,\n      )\n    )\n\n    if self._data_config.cache:\n      ds = ds.cache()\n\n    if self._data_config.ignore_data_errors:\n      ds = ds.apply(tf.data.experimental.ignore_errors())\n\n    options = tf.data.Options()\n    options.experimental_deterministic = False\n    ds = ds.with_options(options)\n\n    return ds\n\n  def _gen(self):\n    for x in self._tf_dataset:\n      yield to_batch(x)\n\n  def to_dataloader(self) -> Dict[str, torch.Tensor]:\n    return torch.utils.data.DataLoader(self, batch_size=None)\n\n  def __iter__(self):\n    return iter(self._gen())\n"
  },
  {
    "path": "projects/home/recap/data/generate_random_data.py",
    "content": "import os\nimport json\nfrom absl import app, flags, logging\nimport tensorflow as tf\nfrom typing import Dict\n\nfrom tml.projects.home.recap.data import tfe_parsing\nfrom tml.core import config as tml_config_mod\nimport tml.projects.home.recap.config as recap_config_mod\n\nflags.DEFINE_string(\"config_path\", None, \"Path to hyperparameters for model.\")\nflags.DEFINE_integer(\"n_examples\", 100, \"Numer of examples to generate.\")\n\nFLAGS = flags.FLAGS\n\n\ndef _generate_random_example(\n  tf_example_schema: Dict[str, tf.io.FixedLenFeature]\n) -> Dict[str, tf.Tensor]:\n  example = {}\n  for feature_name, feature_spec in tf_example_schema.items():\n    dtype = feature_spec.dtype\n    if (dtype == tf.int64) or (dtype == tf.int32):\n      x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype)\n    elif (dtype == tf.float32) or (dtype == tf.float64):\n      x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype)\n    else:\n      raise NotImplementedError(f\"Unknown type {dtype}\")\n\n    example[feature_name] = x\n\n  return example\n\n\ndef _float_feature(value):\n  return tf.train.Feature(float_list=tf.train.FloatList(value=value))\n\n\ndef _int64_feature(value):\n  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))\n\n\ndef _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:\n  feature = {}\n  serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}\n  for feature_name, tensor in x.items():\n    feature[feature_name] = serializers[tensor.dtype](tensor)\n\n  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))\n  return example_proto.SerializeToString()\n\n\ndef generate_data(data_path: str, config: recap_config_mod.RecapConfig):\n  with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, \"r\") as f:\n    seg_dense_schema = json.load(f)[\"schema\"]\n\n  tf_example_schema = tfe_parsing.create_tf_example_schema(\n    
config.train_data,\n    seg_dense_schema,\n  )\n\n  record_filename = os.path.join(data_path, \"random.tfrecord.gz\")\n\n  with tf.io.TFRecordWriter(record_filename, \"GZIP\") as writer:\n    random_example = _generate_random_example(tf_example_schema)\n    serialized_example = _serialize_example(random_example)\n    writer.write(serialized_example)\n\n\ndef _generate_data_main(unused_argv):\n  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)\n\n  # Find the path where to put the data\n  data_path = os.path.dirname(config.train_data.inputs)\n  logging.info(\"Putting random data in %s\", data_path)\n\n  generate_data(data_path, config)\n\n\nif __name__ == \"__main__\":\n  app.run(_generate_data_main)\n"
  },
  {
    "path": "projects/home/recap/data/preprocessors.py",
    "content": "\"\"\"\nPreprocessors applied on DDS workers in order to modify the dataset on the fly.\nSome of these preprocessors are also applied to the model at serving time.\n\"\"\"\nfrom tml.projects.home.recap import config as config_mod\nfrom absl import logging\nimport tensorflow as tf\nimport numpy as np\n\n\nclass TruncateAndSlice(tf.keras.Model):\n  \"\"\"Class for truncating and slicing.\"\"\"\n\n  def __init__(self, truncate_and_slice_config):\n    super().__init__()\n    self._truncate_and_slice_config = truncate_and_slice_config\n\n    if self._truncate_and_slice_config.continuous_feature_mask_path:\n      with tf.io.gfile.GFile(\n        self._truncate_and_slice_config.continuous_feature_mask_path, \"rb\"\n      ) as f:\n        self._continuous_mask = np.load(f).nonzero()[0]\n      logging.info(f\"Slicing {np.sum(self._continuous_mask)} continuous features.\")\n    else:\n      self._continuous_mask = None\n\n    if self._truncate_and_slice_config.binary_feature_mask_path:\n      with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, \"rb\") as f:\n        self._binary_mask = np.load(f).nonzero()[0]\n      logging.info(f\"Slicing {np.sum(self._binary_mask)} binary features.\")\n    else:\n      self._binary_mask = None\n\n  def call(self, inputs, training=None, mask=None):\n    outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))\n    if self._truncate_and_slice_config.continuous_feature_truncation:\n      logging.info(\"Truncating continuous\")\n      outputs[\"continuous\"] = outputs[\"continuous\"][\n        :, : self._truncate_and_slice_config.continuous_feature_truncation\n      ]\n    if self._truncate_and_slice_config.binary_feature_truncation:\n      logging.info(\"Truncating binary\")\n      outputs[\"binary\"] = outputs[\"binary\"][\n        :, : self._truncate_and_slice_config.binary_feature_truncation\n      ]\n    if self._continuous_mask is not None:\n      outputs[\"continuous\"] = 
tf.gather(outputs[\"continuous\"], self._continuous_mask, axis=1)\n    if self._binary_mask is not None:\n      outputs[\"binary\"] = tf.gather(outputs[\"binary\"], self._binary_mask, axis=1)\n    return outputs\n\n\nclass DownCast(tf.keras.Model):\n  \"\"\"Class for Down casting dataset before serialization and transferring to training host.\n  Depends on the data type and the actual data range, the down casting can be lossless or not.\n  It is strongly recommended to compare the metrics before and after down casting.\n  \"\"\"\n\n  def __init__(self, downcast_config):\n    super().__init__()\n    self.config = downcast_config\n    self._type_map = {\n      \"bfloat16\": tf.bfloat16,\n      \"bool\": tf.bool,\n    }\n\n  def call(self, inputs, training=None, mask=None):\n    outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))\n    for feature, type_str in self.config.features.items():\n      assert type_str in self._type_map\n      if type_str == \"bfloat16\":\n        logging.warning(\n          \"Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. 
Please double check metrics.\"\n        )\n      down_cast_data_type = self._type_map[type_str]\n      outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type)\n    return outputs\n\n\nclass RectifyLabels(tf.keras.Model):\n  \"\"\"Class for rectifying labels\"\"\"\n\n  def __init__(self, rectify_label_config):\n    super().__init__()\n    self._config = rectify_label_config\n    self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)\n\n  def call(self, inputs, training=None, mask=None):\n    served_ts_field = self._config.served_timestamp_field\n    impressed_ts_field = self._config.impressed_timestamp_field\n\n    for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():\n      impressed = inputs[impressed_ts_field]\n      served = inputs[served_ts_field]\n      engaged = inputs[engaged_ts_field]\n\n      keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)\n      keep = tf.math.logical_and(keep, engaged - served < self._window)\n      inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))\n\n    return inputs\n\n\nclass ExtractFeatures(tf.keras.Model):\n  \"\"\"Class for extracting individual features from dense tensors by their index.\"\"\"\n\n  def __init__(self, extract_features_config):\n    super().__init__()\n    self._config = extract_features_config\n\n  def call(self, inputs, training=None, mask=None):\n\n    for row in self._config.extract_feature_table:\n      inputs[row.name] = inputs[row.source_tensor][:, row.index]\n\n    return inputs\n\n\nclass DownsampleNegatives(tf.keras.Model):\n  \"\"\"Class for down-sampling/dropping negatives and updating the weights.\n\n  If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]\n  inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]\n  when batch_multiplier=2 and engagements_list=['fav']\n\n  It supports multiple engagements 
(union/logical_or is used to aggregate engagements), so we don't\n  drop positives for any engagement.\n  \"\"\"\n\n  def __init__(self, downsample_negatives_config):\n    super().__init__()\n    self.config = downsample_negatives_config\n\n  def call(self, inputs, training=None, mask=None):\n    labels = self.config.engagements_list\n    # union of engagements\n    mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))\n    n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))\n    batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)\n    negative_weights = tf.math.divide_no_nan(\n      tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),\n      tf.cast(batch_size - n_positives, tf.float32),\n    )\n    new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights\n\n    def _split_by_label_concatenate_and_truncate(input_tensor):\n      # takes positive examples and concatenate with negative examples and truncate\n      # DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)\n      return tf.concat(\n        [\n          input_tensor[mask],\n          input_tensor[tf.math.logical_not(mask)],\n        ],\n        0,\n      )[:batch_size]\n\n    if \"weights\" not in inputs:\n      # add placeholder so logic below applies even if weights aren't present in inputs\n      inputs[\"weights\"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])\n\n    for tensor in inputs:\n      if tensor == \"weights\":\n        inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])\n\n      inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])\n\n    return inputs\n\n\ndef build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):\n  \"\"\"Builds a preprocess model to apply all preprocessing stages.\"\"\"\n  if mode == config_mod.JobMode.INFERENCE:\n    
logging.info(\"Not building preprocessors for dataloading since we are in Inference mode.\")\n    return None\n\n  preprocess_models = []\n  if preprocess_config.downsample_negatives:\n    preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))\n  if preprocess_config.truncate_and_slice:\n    preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))\n  if preprocess_config.downcast:\n    preprocess_models.append(DownCast(preprocess_config.downcast))\n  if preprocess_config.rectify_labels:\n    preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))\n  if preprocess_config.extract_features:\n    preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))\n\n  if len(preprocess_models) == 0:\n    raise ValueError(\"No known preprocessor.\")\n\n  class PreprocessModel(tf.keras.Model):\n    def __init__(self, preprocess_models):\n      super().__init__()\n      self.preprocess_models = preprocess_models\n\n    def call(self, inputs, training=None, mask=None):\n      outputs = inputs\n      for model in self.preprocess_models:\n        outputs = model(outputs, training, mask)\n      return outputs\n\n  if len(preprocess_models) > 1:\n    logging.warning(\n      \"With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders.\"\n    )\n  return PreprocessModel(preprocess_models)\n"
  },
  {
    "path": "projects/home/recap/data/tfe_parsing.py",
    "content": "import functools\nimport json\n\nfrom tml.projects.home.recap.data import config as recap_data_config\n\nfrom absl import logging\nimport tensorflow as tf\n\n\nDEFAULTS_MAP = {\"int64_list\": 0, \"float_list\": 0.0, \"bytes_list\": \"\"}\nDTYPE_MAP = {\"int64_list\": tf.int64, \"float_list\": tf.float32, \"bytes_list\": tf.string}\n\n\ndef create_tf_example_schema(\n  data_config: recap_data_config.SegDenseSchema,\n  segdense_schema,\n):\n  \"\"\"Generate schema for deserializing tf.Example.\n\n  Args:\n    data_config: Data config. Its seg_dense_schema features (including renamed ones) and its\n      task names (the labels) select which entries of segdense_schema are kept.\n    segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).\n\n  Returns:\n    A dictionary schema suitable for deserializing tf.Example.\n\n  Raises:\n    ValueError: If a used feature or label is missing from segdense_schema.\n  \"\"\"\n  segdense_config = data_config.seg_dense_schema\n  labels = list(data_config.tasks.keys())\n  used_features = (\n    segdense_config.features + list(segdense_config.renamed_features.values()) + labels\n  )\n  logging.info(used_features)\n\n  tfe_schema = {}\n  for entry in segdense_schema:\n    feature_name = entry[\"feature_name\"]\n\n    if feature_name in used_features:\n      length = entry[\"length\"]\n      dtype = entry[\"dtype\"]\n\n      if feature_name in labels:\n        logging.info(f\"Label: feature name is {feature_name} type is {dtype}\")\n        tfe_schema[feature_name] = tf.io.FixedLenFeature(\n          length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]\n        )\n      elif length == -1:\n        tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])\n      else:\n        tfe_schema[feature_name] = tf.io.FixedLenFeature(\n          length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length\n        )\n  for feature_name in used_features:\n    if feature_name not in tfe_schema:\n      raise ValueError(f\"{feature_name} missing from schema: {segdense_config.schema_path}.\")\n  return tfe_schema\n\n\n@functools.lru_cache(1)\ndef make_mantissa_mask(mask_length: int) -> tf.Tensor:\n  \"\"\"For
 experimenting with emulating bfloat16 or less precise types.\"\"\"\n  return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)\n\n\ndef mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:\n  \"\"\"For experimenting with emulating bfloat16 or less precise types.\"\"\"\n  mask: tf.Tensor = make_mantissa_mask(mask_length)\n  return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)\n\n\ndef parse_tf_example(\n  serialized_example,\n  tfe_schema,\n  seg_dense_schema_config,\n):\n  \"\"\"Parse serialized tf.Example into dict of tensors.\n\n  Args:\n    serialized_example: Serialized tf.Example to be parsed.\n    tfe_schema: Dictionary schema suitable for deserializing tf.Example.\n    seg_dense_schema_config: Seg dense schema config. Its renamed_features maps new -> old\n      feature names; the optional mask_mantissa_features maps feature names to mask lengths.\n\n  Returns:\n    Dictionary of tensors to be used as model input.\n  \"\"\"\n  inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)\n\n  for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():\n    inputs[new_feature_name] = inputs.pop(old_feature_name)\n\n  # This should not actually be used except for experimentation with low precision floats.\n  if \"mask_mantissa_features\" in seg_dense_schema_config:\n    for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():\n      inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)\n\n  # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible\n  # at TF level.\n  # We should not return empty tensors if we don't use embeddings.\n  # Otherwise, it breaks numpy->pt conversion\n  renamed_keys = list(seg_dense_schema_config.renamed_features.keys())\n  for renamed_key in renamed_keys:\n    if \"embedding\" in renamed_key and (renamed_key not in inputs):\n      inputs[renamed_key] = tf.zeros([], tf.float32)\n\n  logging.info(f\"parsed example and inputs are {inputs}\")\n  return inputs\n\n\ndef get_seg_dense_parse_fn(data_config:
 recap_data_config.RecapDataConfig):\n  \"\"\"Build a parse function for serialized seg dense tf.Examples.\n\n  Reads the schema JSON at data_config.seg_dense_schema.schema_path and binds the resulting\n  tf.Example schema into parse_tf_example. This is a placeholder: in the future, when we use\n  more seg dense variations, we can change this.\n  \"\"\"\n  with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, \"r\") as f:\n    seg_dense_schema = json.load(f)[\"schema\"]\n\n  tf_example_schema = create_tf_example_schema(\n    data_config,\n    seg_dense_schema,\n  )\n\n  logging.info(\"***** TF Example Schema *****\")\n  logging.info(tf_example_schema)\n\n  parse = functools.partial(\n    parse_tf_example,\n    tfe_schema=tf_example_schema,\n    seg_dense_schema_config=data_config.seg_dense_schema,\n  )\n  return parse\n"
  },
  {
    "path": "projects/home/recap/data/util.py",
    "content": "from typing import Mapping, Tuple, Union\nimport torch\nimport torchrec\nimport numpy as np\nimport tensorflow as tf\n\n\ndef keyed_tensor_from_tensors_dict(\n  tensor_map: Mapping[str, torch.Tensor]\n) -> \"torchrec.KeyedTensor\":\n  \"\"\"\n  Convert a dictionary of torch tensor to torchrec keyed tensor\n  Args:\n    tensor_map:\n\n  Returns:\n\n  \"\"\"\n  keys = list(tensor_map.keys())\n  # We expect batch size to be first dim. However, if we get a shape [Batch_size],\n  # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is\n  # [Batch_size x 1].\n  values = [\n    tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)\n    for key in keys\n  ]\n  return torchrec.KeyedTensor.from_tensor_list(keys, values)\n\n\ndef _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n  if tensor.is_sparse:\n    x = tensor.coalesce()  # Ensure that the indices are ordered.\n    lengths = torch.bincount(x.indices()[0])\n    values = x.values()\n  else:\n    values = tensor\n    lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)\n  return values, lengths\n\n\ndef jagged_tensor_from_tensor(tensor: torch.Tensor) -> \"torchrec.JaggedTensor\":\n  \"\"\"\n  Convert a torch tensor to torchrec jagged tensor.\n  Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.\n        For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the\n          dense_shape of the sparse tensor can be arbitrary.\n  Args:\n    tensor: a torch (sparse) tensor.\n  Returns:\n  \"\"\"\n  values, lengths = _compute_jagged_tensor_from_tensor(tensor)\n  return torchrec.JaggedTensor(values=values, lengths=lengths)\n\n\ndef keyed_jagged_tensor_from_tensors_dict(\n  tensor_map: Mapping[str, torch.Tensor]\n) -> \"torchrec.KeyedJaggedTensor\":\n  \"\"\"\n  Convert a dictionary of (sparse) 
torch tensors to torchrec keyed jagged tensor.\n  Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.\n        For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the\n          dense_shape of the sparse tensor can be arbitrary.\n  Args:\n    tensor_map:\n\n  Returns:\n\n  \"\"\"\n\n  if not tensor_map:\n    return torchrec.KeyedJaggedTensor(\n      keys=[],\n      values=torch.zeros(0, dtype=torch.int),\n      lengths=torch.zeros(0, dtype=torch.int),\n    )\n  values = []\n  lengths = []\n  for tensor in tensor_map.values():\n    tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)\n    values.append(torch.squeeze(tensor_val))\n    lengths.append(tensor_len)\n\n  values = torch.cat(values, axis=0)\n  lengths = torch.cat(lengths, axis=0)\n\n  return torchrec.KeyedJaggedTensor(\n    keys=list(tensor_map.keys()),\n    values=values,\n    lengths=lengths,\n  )\n\n\ndef _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:\n  return tf_tensor._numpy()  # noqa\n\n\ndef _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:\n  tensor = _tf_to_numpy(tensor)\n  # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent\n  if tensor.dtype.name == \"bfloat16\":\n    tensor = tensor.astype(np.float32)\n\n  tensor = torch.from_numpy(tensor)\n  if pin_memory:\n    tensor = tensor.pin_memory()\n  return tensor\n\n\ndef sparse_or_dense_tf_to_torch(\n  tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool\n) -> torch.Tensor:\n  if isinstance(tensor, tf.SparseTensor):\n    tensor = torch.sparse_coo_tensor(\n      _dense_tf_to_torch(tensor.indices, pin_memory).t(),\n      _dense_tf_to_torch(tensor.values, pin_memory),\n      torch.Size(_tf_to_numpy(tensor.dense_shape)),\n    )\n  else:\n    tensor = _dense_tf_to_torch(tensor, pin_memory)\n  return tensor\n"
  },
  {
    "path": "projects/home/recap/embedding/config.py",
    "content": "from typing import List, Optional\nimport tml.core.config as base_config\nfrom tml.optimizers import config as optimizer_config\n\nimport pydantic\n\n\nclass EmbeddingSnapshot(base_config.BaseConfig):\n  \"\"\"Configuration for Embedding snapshot\"\"\"\n\n  emb_name: str = pydantic.Field(\n    ..., description=\"Name of the embedding table from the loaded snapshot\"\n  )\n  embedding_snapshot_uri: str = pydantic.Field(\n    ..., description=\"Path to torchsnapshot of the embedding\"\n  )\n\n\n# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig\nclass EmbeddingBagConfig(base_config.BaseConfig):\n  \"\"\"Configuration for EmbeddingBag.\"\"\"\n\n  name: str = pydantic.Field(..., description=\"name of embedding bag\")\n  num_embeddings: int = pydantic.Field(..., description=\"size of embedding dictionary\")\n  embedding_dim: int = pydantic.Field(..., description=\"size of each embedding vector\")\n  pretrained: EmbeddingSnapshot = pydantic.Field(None, description=\"Snapshot properties\")\n  vocab: str = pydantic.Field(\n    None, description=\"Directory to parquet files of mapping from entity ID to table index.\"\n  )\n\n\nclass EmbeddingOptimizerConfig(base_config.BaseConfig):\n  learning_rate: optimizer_config.LearningRate = pydantic.Field(\n    None, description=\"learning rate scheduler for the EBC\"\n  )\n  init_learning_rate: float = pydantic.Field(description=\"initial learning rate for the EBC\")\n  # NB: Only sgd is supported right now and implicitly.\n  # FBGemm only supports simple exact_sgd which only takes LR as an argument.\n\n\nclass LargeEmbeddingsConfig(base_config.BaseConfig):\n  \"\"\"Configuration for EmbeddingBagCollection.\n\n  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.\n  \"\"\"\n\n  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description=\"list of embedding tables\")\n  optimizer: EmbeddingOptimizerConfig\n  
tables_to_log: List[str] = pydantic.Field(\n    None, description=\"list of embedding table names that we want to log during training\"\n  )\n\n\nclass StratifierConfig(base_config.BaseConfig):\n  name: str\n  index: int\n  value: int\n\n\nclass SmallEmbeddingBagConfig(base_config.BaseConfig):\n  \"\"\"Configuration for SmallEmbeddingBag.\"\"\"\n\n  name: str = pydantic.Field(..., description=\"name of embedding bag\")\n  num_embeddings: int = pydantic.Field(..., description=\"size of embedding dictionary\")\n  embedding_dim: int = pydantic.Field(..., description=\"size of each embedding vector\")\n  index: int = pydantic.Field(..., description=\"index in the discrete tensor to look for\")\n\n\nclass SmallEmbeddingBagConfig(base_config.BaseConfig):\n  \"\"\"Configuration for SmallEmbeddingBag.\"\"\"\n\n  name: str = pydantic.Field(..., description=\"name of embedding bag\")\n  num_embeddings: int = pydantic.Field(..., description=\"size of embedding dictionary\")\n  embedding_dim: int = pydantic.Field(..., description=\"size of each embedding vector\")\n  index: int = pydantic.Field(..., description=\"index in the discrete tensor to look for\")\n\n\nclass SmallEmbeddingsConfig(base_config.BaseConfig):\n  \"\"\"Configuration for SmallEmbeddingConfig.\n\n  Here we can use discrete features that already are present in our TFRecords generated using\n  segdense conversion as \"home_recap_2022_discrete__segdense_vals\" which are available in\n  the model as \"discrete_features\", and embed a user-defined set of them with configurable\n  dimensions and vocabulary sizes.\n\n  Compared with LargeEmbedding, this config is for small embedding tables that can fit inside\n  the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at\n  serving time due to size (>>1 GB).\n\n  This small embeddings table uses the same optimizer as the rest of the model.\"\"\"\n\n  tables: List[SmallEmbeddingBagConfig] = pydantic.Field(\n    ..., description=\"list of 
embedding tables\"\n  )\n"
  },
  {
    "path": "projects/home/recap/main.py",
    "content": "import datetime\nimport os\nfrom typing import Callable, List, Optional, Tuple\nimport tensorflow as tf\n\nimport tml.common.checkpointing.snapshot as snapshot_lib\nfrom tml.common.device import setup_and_get_device\nfrom tml.core import config as tml_config_mod\nimport tml.core.custom_training_loop as ctl\nfrom tml.core import debug_training_loop\nfrom tml.core import losses\nfrom tml.core.loss_type import LossType\nfrom tml.model import maybe_shard_model\n\n\nimport tml.projects.home.recap.data.dataset as ds\nimport tml.projects.home.recap.config as recap_config_mod\nimport tml.projects.home.recap.optimizer as optimizer_mod\n\n\n# from tml.projects.home.recap import feature\nimport tml.projects.home.recap.model as model_mod\nimport torchmetrics as tm\nimport torch\nimport torch.distributed as dist\nfrom torchrec.distributed.model_parallel import DistributedModelParallel\n\nfrom absl import app, flags, logging\n\nflags.DEFINE_string(\"config_path\", None, \"Path to hyperparameters for model.\")\nflags.DEFINE_bool(\"debug_loop\", False, \"Run with debug loop (slow)\")\n\nFLAGS = flags.FLAGS\n\n\ndef run(unused_argv: str, data_service_dispatcher: Optional[str] = None):\n  print(\"#\" * 100)\n\n  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)\n  logging.info(\"Config: %s\", config.pretty_print())\n\n  device = setup_and_get_device()\n\n  # Always enable tensorfloat on supported devices.\n  torch.backends.cuda.matmul.allow_tf32 = True\n  torch.backends.cudnn.allow_tf32 = True\n\n  loss_fn = losses.build_multi_task_loss(\n    loss_type=LossType.BCE_WITH_LOGITS,\n    tasks=list(config.model.tasks.keys()),\n    pos_weights=[task.pos_weight for task in config.model.tasks.values()],\n  )\n\n  # Since the prod model doesn't use large embeddings, for now we won't support them.\n  assert config.model.large_embeddings is None\n\n  train_dataset = ds.RecapDataset(\n    data_config=config.train_data,\n    
dataset_service=data_service_dispatcher,\n    mode=recap_config_mod.JobMode.TRAIN,\n    compression=config.train_data.dataset_service_compression,\n    vocab_mapper=None,\n    repeat=True,\n  )\n\n  train_iterator = iter(train_dataset.to_dataloader())\n\n  torch_element_spec = train_dataset.torch_element_spec\n\n  model = model_mod.create_ranking_model(\n    data_spec=torch_element_spec[0],\n    config=config,\n    loss_fn=loss_fn,\n    device=device,\n  )\n\n  optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)\n\n  model = maybe_shard_model(model, device)\n\n  datetime_str = datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M\")\n  print(f\"{datetime_str}\\n\", end=\"\")\n\n  if FLAGS.debug_loop:\n    logging.warning(\"Running debug mode, slow!\")\n    train_mod = debug_training_loop\n  else:\n    train_mod = ctl\n\n  train_mod.train(\n    model=model,\n    optimizer=optimizer,\n    device=device,\n    save_dir=config.training.save_dir,\n    logging_interval=config.training.train_log_every_n,\n    train_steps=config.training.num_train_steps,\n    checkpoint_frequency=config.training.checkpoint_every_n,\n    dataset=train_iterator,\n    worker_batch_size=config.train_data.global_batch_size,\n    enable_amp=False,\n    initial_checkpoint_dir=config.training.initial_checkpoint_dir,\n    gradient_accumulation=config.training.gradient_accumulation,\n    scheduler=scheduler,\n  )\n\n\nif __name__ == \"__main__\":\n  app.run(run)\n"
  },
  {
    "path": "projects/home/recap/model/__init__.py",
    "content": "from tml.projects.home.recap.model.entrypoint import (\n  create_ranking_model,\n  sanitize,\n  unsanitize,\n  MultiTaskRankingModel,\n)\nfrom tml.projects.home.recap.model.model_and_loss import ModelAndLoss\n"
  },
  {
    "path": "projects/home/recap/model/config.py",
    "content": "\"\"\"Configuration for the main Recap model.\"\"\"\n\nimport enum\nfrom typing import List, Optional, Dict\n\nimport tml.core.config as base_config\nfrom tml.projects.home.recap.embedding import config as embedding_config\n\nimport pydantic\n\n\nclass DropoutConfig(base_config.BaseConfig):\n  \"\"\"Configuration for the dropout layer.\"\"\"\n\n  rate: pydantic.PositiveFloat = pydantic.Field(\n    0.1, description=\"Fraction of inputs to be dropped.\"\n  )\n\n\nclass LayerNormConfig(base_config.BaseConfig):\n  \"\"\"Configruation for the layer normalization.\"\"\"\n\n  epsilon: float = pydantic.Field(\n    1e-3, description=\"Small float added to variance to avoid dividing by zero.\"\n  )\n  axis: int = pydantic.Field(-1, description=\"Axis or axes to normalize across.\")\n  center: bool = pydantic.Field(True, description=\"Whether to add learnable center.\")\n  scale: bool = pydantic.Field(True, description=\"Whether to add learnable scale.\")\n\n\nclass BatchNormConfig(base_config.BaseConfig):\n  \"\"\"Configuration of the batch normalization layer.\"\"\"\n\n  epsilon: pydantic.PositiveFloat = 1e-5\n  momentum: pydantic.PositiveFloat = 0.9\n  training_mode_at_inference_time: bool = False\n  use_renorm: bool = False\n  center: bool = pydantic.Field(True, description=\"Whether to add learnable center.\")\n  scale: bool = pydantic.Field(True, description=\"Whether to add learnable scale.\")\n\n\nclass DenseLayerConfig(base_config.BaseConfig):\n  layer_size: pydantic.PositiveInt\n  dropout: DropoutConfig = pydantic.Field(None, description=\"Optional dropout config for layer.\")\n\n\nclass MlpConfig(base_config.BaseConfig):\n  \"\"\"Configuration for MLP model.\"\"\"\n\n  layer_sizes: List[pydantic.PositiveInt] = pydantic.Field(None, one_of=\"mlp_layer_definition\")\n  layers: List[DenseLayerConfig] = pydantic.Field(None, one_of=\"mlp_layer_definition\")\n\n\nclass BatchNormConfig(base_config.BaseConfig):\n  \"\"\"Configuration for the batch norm 
layer.\"\"\"\n\n  affine: bool = pydantic.Field(True, description=\"Use affine transformation.\")\n  momentum: pydantic.PositiveFloat = pydantic.Field(\n    0.1, description=\"Forgetting parameter in moving average.\"\n  )\n\n\nclass DoubleNormLogConfig(base_config.BaseConfig):\n  batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)\n  clip_magnitude: float = pydantic.Field(\n    5.0, description=\"Threshold to clip the normalized input values.\"\n  )\n  layer_norm_config: Optional[LayerNormConfig] = pydantic.Field(None)\n\n\nclass Log1pAbsConfig(base_config.BaseConfig):\n  \"\"\"Simple configuration where only the log transform is performed.\"\"\"\n\n\nclass ClipLog1pAbsConfig(base_config.BaseConfig):\n  clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(\n    3e38, description=\"Threshold to clip the input values.\"\n  )\n\n\nclass ZScoreLogConfig(base_config.BaseConfig):\n  analysis_path: str\n  schema_path: str = pydantic.Field(\n    None,\n    description=\"Schema path which feaure statistics are generated with. 
Can be different from scehma in data config.\",\n  )\n  clip_magnitude: float = pydantic.Field(\n    5.0, description=\"Threshold to clip the normalized input values.\"\n  )\n  use_batch_norm: bool = pydantic.Field(\n    False, description=\"Option to use batch normalization on the inputs.\"\n  )\n  use_renorm: bool = pydantic.Field(\n    False, description=\"Option to use batch renormalization for trainig and serving consistency.\"\n  )\n  use_bq_stats: bool = pydantic.Field(\n    False, description=\"Option to load the partitioned json files from BQ as statistics.\"\n  )\n\n\nclass FeaturizationConfig(base_config.BaseConfig):\n  \"\"\"Configuration for featurization.\"\"\"\n\n  log1p_abs_config: Log1pAbsConfig = pydantic.Field(None, one_of=\"featurization\")\n  clip_log1p_abs_config: ClipLog1pAbsConfig = pydantic.Field(None, one_of=\"featurization\")\n  z_score_log_config: ZScoreLogConfig = pydantic.Field(None, one_of=\"featurization\")\n  double_norm_log_config: DoubleNormLogConfig = pydantic.Field(None, one_of=\"featurization\")\n  feature_names_to_concat: List[str] = pydantic.Field(\n    [\"binary\"], description=\"Feature names to concatenate as raw values with continuous features.\"\n  )\n\n\nclass DropoutConfig(base_config.BaseConfig):\n  \"\"\"Configuration for the dropout layer.\"\"\"\n\n  rate: pydantic.PositiveFloat = pydantic.Field(\n    0.1, description=\"Fraction of inputs to be dropped.\"\n  )\n\n\nclass MlpConfig(base_config.BaseConfig):\n  \"\"\"Configuration for MLP model.\"\"\"\n\n  layer_sizes: List[pydantic.PositiveInt]\n  batch_norm: BatchNormConfig = pydantic.Field(\n    None, description=\"Optional batch norm configuration.\"\n  )\n  dropout: DropoutConfig = pydantic.Field(None, description=\"Optional dropout configuration.\")\n  final_layer_activation: bool = pydantic.Field(\n    False, description=\"Whether to include activation on final layer.\"\n  )\n\n\nclass DcnConfig(base_config.BaseConfig):\n  \"\"\"Config for DCN model.\"\"\"\n\n  
poly_degree: pydantic.PositiveInt\n  projection_dim: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"Factorizes main DCN matmul with projection.\"\n  )\n\n  parallel_mlp: Optional[MlpConfig] = pydantic.Field(\n    None, description=\"Config for the mlp if used. If None, only the cross layers are used.\"\n  )\n  use_parallel: bool = pydantic.Field(True, description=\"Whether to use parallel DCN.\")\n\n  output_mlp: Optional[MlpConfig] = pydantic.Field(None, description=\"Config for the output mlp.\")\n\n\nclass MaskBlockConfig(base_config.BaseConfig):\n  output_size: int\n  reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(\n    None, one_of=\"aggregation_size\"\n  )\n  aggregation_size: Optional[pydantic.PositiveInt] = pydantic.Field(\n    None, description=\"Specify the aggregation size directly.\", one_of=\"aggregation_size\"\n  )\n  input_layer_norm: bool\n\n\nclass MaskNetConfig(base_config.BaseConfig):\n  mask_blocks: List[MaskBlockConfig]\n  mlp: Optional[MlpConfig] = pydantic.Field(None, description=\"MLP Configuration for parallel\")\n  use_parallel: bool = pydantic.Field(False, description=\"Whether to use parallel MaskNet.\")\n\n\nclass PositionDebiasConfig(base_config.BaseConfig):\n  \"\"\"\n  Configuration for Position Debias.\n  \"\"\"\n\n  max_position: int = pydantic.Field(256, description=\"Bucket all later positions.\")\n  num_dims: pydantic.PositiveInt = pydantic.Field(\n    64, description=\"Number of dimensions in embedding.\"\n  )\n  drop_probability: float = pydantic.Field(0.5, description=\"Probability of dropping position.\")\n\n  # Currently it should be 51 based on dataset being tested at the time of writing this model\n  # However, no default provided here to make sure user of the model is aware of its importance.\n  position_feature_index: int = pydantic.Field(\n    description=\"The index of the position feature in the discrete features\"\n  )\n\n\nclass AffineMap(base_config.BaseConfig):\n  \"\"\"An 
affine map that scales the logits into the appropriate range.\"\"\"\n\n  scale: float = pydantic.Field(1.0)\n  bias: float = pydantic.Field(0.0)\n\n\nclass DLRMConfig(base_config.BaseConfig):\n  bottom_mlp: MlpConfig = pydantic.Field(\n    ...,\n    description=\"Bottom mlp, the output to be combined with sparse features and feed to interaction\",\n  )\n  top_mlp: MlpConfig = pydantic.Field(..., description=\"Top mlp, generate the final output\")\n\n\nclass TaskModel(base_config.BaseConfig):\n  mlp_config: MlpConfig = pydantic.Field(None, one_of=\"architecture\")\n  dcn_config: DcnConfig = pydantic.Field(None, one_of=\"architecture\")\n  dlrm_config: DLRMConfig = pydantic.Field(None, one_of=\"architecture\")\n  mask_net_config: MaskNetConfig = pydantic.Field(None, one_of=\"architecture\")\n\n  affine_map: AffineMap = pydantic.Field(\n    None,\n    description=\"Affine map applied to logits so we can represent a broader range of probabilities.\",\n  )\n  # DANGER DANGER: not implemented yet.\n  # loss_weight: float = pydantic.Field(1.0, description=\"Weight for task in loss.\")\n  pos_weight: float = pydantic.Field(1.0, description=\"Weight of positive in loss.\")\n\n\nclass MultiTaskType(str, enum.Enum):\n  SHARE_NONE = \"share_none\"  # Tasks are separate.\n  SHARE_ALL = \"share_all\"  # Tasks share same backbone.\n  SHARE_PARTIAL = \"share_partial\"  # Tasks share some backbone, but have their own portions.\n\n\nclass ModelConfig(base_config.BaseConfig):\n  \"\"\"Specify model architecture.\"\"\"\n\n  tasks: Dict[str, TaskModel] = pydantic.Field(\n    description=\"Specification of architecture per task.\"\n  )\n\n  large_embeddings: embedding_config.LargeEmbeddingsConfig = pydantic.Field(None)\n  small_embeddings: embedding_config.SmallEmbeddingsConfig = pydantic.Field(None)\n  # Not implemented yet.\n  # multi_task_loss_reduction_fn: str = \"mean\"\n\n  position_debias_config: PositionDebiasConfig = pydantic.Field(\n    default=None, description=\"position 
debias model configuration\"\n  )\n\n  featurization_config: FeaturizationConfig = pydantic.Field(None)\n\n  multi_task_type: MultiTaskType = pydantic.Field(\n    MultiTaskType.SHARE_NONE, description=\"Multi task architecture\"\n  )\n\n  backbone: TaskModel = pydantic.Field(None, description=\"Type of architecture for the backbone.\")\n  stratifiers: List[embedding_config.StratifierConfig] = pydantic.Field(\n    default=None, description=\"Discrete features and values to stratify metrics by.\"\n  )\n\n  @pydantic.root_validator()\n  def _validate_mtl(cls, values):\n    if values.get(\"multi_task_type\", None) is None:\n      return values\n    elif values[\"multi_task_type\"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:\n      if values.get(\"backbone\", None) is None:\n        raise ValueError(\"Require `backbone` for SHARE_ALL and SHARE_PARTIAL.\")\n    elif values[\"multi_task_type\"] in [\n      MultiTaskType.SHARE_NONE,\n    ]:\n      if values.get(\"backbone\", None) is not None:\n        raise ValueError(\"Can not have backbone if the share type is SHARE_NONE\")\n    return values\n"
  },
  {
    "path": "projects/home/recap/model/entrypoint.py",
    "content": "from __future__ import annotations\n\nfrom absl import logging\nimport torch\nfrom typing import Optional, Callable, Mapping, Dict, Sequence, TYPE_CHECKING\nfrom tml.projects.home.recap.model import feature_transform\nfrom tml.projects.home.recap.model import config as model_config_mod\nfrom tml.projects.home.recap.model import mlp\nfrom tml.projects.home.recap.model import mask_net\nfrom tml.projects.home.recap.model import numeric_calibration\nfrom tml.projects.home.recap.model.model_and_loss import ModelAndLoss\nimport tml.projects.home.recap.model.config as model_config_mod\n\nif TYPE_CHECKING:\n  from tml.projects.home.recap import config as config_mod\n  from tml.projects.home.recap.data.config import RecapDataConfig\n  from tml.projects.home.recap.model.config import ModelConfig\n\n\ndef sanitize(task_name):\n  return task_name.replace(\".\", \"__\")\n\n\ndef unsanitize(sanitized_task_name):\n  return sanitized_task_name.replace(\"__\", \".\")\n\n\ndef _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):\n  \"\"\"Builds a model for a single task.\"\"\"\n  if task.mlp_config:\n    return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)\n  elif task.dcn_config:\n    return dcn.Dcn(dcn_config=task.dcn_config, in_features=input_shape)\n  elif task.mask_net_config:\n    return mask_net.MaskNet(mask_net_config=task.mask_net_config, in_features=input_shape)\n  else:\n    raise ValueError(\"This should never be reached.\")\n\n\nclass MultiTaskRankingModel(torch.nn.Module):\n  \"\"\"Multi-task ranking model.\"\"\"\n\n  def __init__(\n    self,\n    input_shapes: Mapping[str, torch.Size],\n    config: ModelConfig,\n    data_config: RecapDataConfig,\n    return_backbone: bool = False,\n  ):\n    \"\"\"Constructor for Multi task learning.\n\n    Assumptions made:\n    1. 
Tasks specified in data config match model architecture.\n\n    These are all validated in config.\n    \"\"\"\n    super().__init__()\n\n    self._config = config\n    self._data_config = data_config\n\n    self._preprocessor = feature_transform.build_features_preprocessor(\n      config.featurization_config, input_shapes\n    )\n\n    self.return_backbone = return_backbone\n\n    self.embeddings = None\n    self.small_embeddings = None\n    embedding_dims = 0\n    if config.large_embeddings:\n      from large_embeddings.models.learnable_embeddings import LargeEmbeddings\n\n      self.embeddings = LargeEmbeddings(large_embeddings_config=config.large_embeddings)\n\n      embedding_dims += sum([table.embedding_dim for table in config.large_embeddings.tables])\n      logging.info(f\"Emb dim: {embedding_dims}\")\n\n    if config.small_embeddings:\n      self.small_embeddings = SmallEmbedding(config.small_embeddings)\n      embedding_dims += sum([table.embedding_dim for table in config.small_embeddings.tables])\n      logging.info(f\"Emb dim (with small embeddings): {embedding_dims}\")\n\n    if \"user_embedding\" in data_config.seg_dense_schema.renamed_features:\n      embedding_dims += input_shapes[\"user_embedding\"][-1]\n      self._user_embedding_layer_norm = torch.nn.LayerNorm(input_shapes[\"user_embedding\"][-1])\n    else:\n      self._user_embedding_layer_norm = None\n    if \"user_eng_embedding\" in data_config.seg_dense_schema.renamed_features:\n      embedding_dims += input_shapes[\"user_eng_embedding\"][-1]\n      self._user_eng_embedding_layer_norm = torch.nn.LayerNorm(\n        input_shapes[\"user_eng_embedding\"][-1]\n      )\n    else:\n      self._user_eng_embedding_layer_norm = None\n    if \"author_embedding\" in data_config.seg_dense_schema.renamed_features:\n      embedding_dims += input_shapes[\"author_embedding\"][-1]\n      self._author_embedding_layer_norm = torch.nn.LayerNorm(input_shapes[\"author_embedding\"][-1])\n    else:\n      
self._author_embedding_layer_norm = None\n\n    input_dims = input_shapes[\"continuous\"][-1] + input_shapes[\"binary\"][-1] + embedding_dims\n\n    if config.position_debias_config:\n      self.position_debias_model = PositionDebias(config.position_debias_config)\n      input_dims += self.position_debias_model.out_features\n    else:\n      self.position_debias_model = None\n    logging.info(f\"input dim: {input_dims}\")\n\n    if config.multi_task_type in [\n      model_config_mod.MultiTaskType.SHARE_ALL,\n      model_config_mod.MultiTaskType.SHARE_PARTIAL,\n    ]:\n      self._backbone = _build_single_task_model(config.backbone, input_dims)\n    else:\n      self._backbone = None\n\n    _towers: Dict[str, torch.nn.Module] = {}\n    _calibrators: Dict[str, torch.nn.Module] = {}\n    _affine_maps: Dict[str, torch.nn.Module] = {}\n\n    for task_name, task_architecture in config.tasks.items():\n      safe_name = sanitize(task_name)\n\n      # Complex input dimension calculation.\n      if config.multi_task_type == model_config_mod.MultiTaskType.SHARE_NONE:\n        num_inputs = input_dims\n      elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:\n        num_inputs = self._backbone.out_features\n      elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:\n        num_inputs = input_dims + self._backbone.out_features\n      else:\n        raise ValueError(\"Unreachable branch of enum.\")\n\n      # Annoyingly, ModuleDict doesn't allow . 
inside key names.\n      _towers[safe_name] = _build_single_task_model(task_architecture, num_inputs)\n\n      if task_architecture.affine_map:\n        affine_map = torch.nn.Linear(1, 1)\n        affine_map.weight.data = torch.tensor([[task_architecture.affine_map.scale]])\n        affine_map.bias.data = torch.tensor([task_architecture.affine_map.bias])\n        _affine_maps[safe_name] = affine_map\n      else:\n        _affine_maps[safe_name] = torch.nn.Identity()\n\n      _calibrators[safe_name] = numeric_calibration.NumericCalibration(\n        pos_downsampling_rate=data_config.tasks[task_name].pos_downsampling_rate,\n        neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,\n      )\n\n    self._task_names = list(config.tasks.keys())\n    self._towers = torch.nn.ModuleDict(_towers)\n    self._affine_maps = torch.nn.ModuleDict(_affine_maps)\n    self._calibrators = torch.nn.ModuleDict(_calibrators)\n\n    self._counter = torch.autograd.Variable(torch.tensor(0), requires_grad=False)\n\n  def forward(\n    self,\n    continuous_features: torch.Tensor,\n    binary_features: torch.Tensor,\n    discrete_features: Optional[torch.Tensor] = None,\n    sparse_features=None,  # Optional[KeyedJaggedTensor]\n    user_embedding: Optional[torch.Tensor] = None,\n    user_eng_embedding: Optional[torch.Tensor] = None,\n    author_embedding: Optional[torch.Tensor] = None,\n    labels: Optional[torch.Tensor] = None,\n    weights: Optional[torch.Tensor] = None,\n  ):\n    concat_dense_features = [\n      self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)\n    ]\n\n    if self.embeddings:\n      concat_dense_features.append(self.embeddings(sparse_features))\n\n    # Twhin embedding layer norms\n    if self.small_embeddings:\n      if discrete_features is None:\n        raise ValueError(\n          \"Forward arg discrete_features is None, but since small_embeddings are used, a Tensor is expected.\"\n        )\n      
concat_dense_features.append(self.small_embeddings(discrete_features))\n\n    if self._user_embedding_layer_norm:\n      if user_embedding is None:\n        raise ValueError(\n          \"Forward arg user_embedding is None, but since Twhin user_embeddings are used by the model, a Tensor is expected.\"\n        )\n      concat_dense_features.append(self._user_embedding_layer_norm(user_embedding))\n\n    if self._user_eng_embedding_layer_norm:\n      if user_eng_embedding is None:\n        raise ValueError(\n          \"Forward arg user_eng_embedding is None, but since Twhin user_eng_embeddings are used by the model, a Tensor is expected.\"\n        )\n      concat_dense_features.append(self._user_eng_embedding_layer_norm(user_eng_embedding))\n\n    if self._author_embedding_layer_norm:\n      if author_embedding is None:\n        raise ValueError(\n          \"Forward arg author_embedding is None, but since Twhin author_embeddings are used by the model, a Tensor is expected.\"\n        )\n      concat_dense_features.append(self._author_embedding_layer_norm(author_embedding))\n\n    if self.position_debias_model:\n      if discrete_features is None:\n        raise ValueError(\n          \"Forward arg discrete_features is None, but since position_debias_model is used, a Tensor is expected.\"\n        )\n      concat_dense_features.append(self.position_debias_model(discrete_features))\n\n    if discrete_features is not None and not (self.position_debias_model or self.small_embeddings):\n      logging.warning(\"Forward arg discrete_features is passed, but never used.\")\n\n    concat_dense_features = torch.cat(concat_dense_features, dim=1)\n\n    if self._backbone:\n      if self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:\n        net = self._backbone(concat_dense_features)[\"output\"]\n      elif self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:\n        net = torch.cat(\n          [concat_dense_features, 
self._backbone(concat_dense_features)[\"output\"]], dim=1\n        )\n    else:\n      net = concat_dense_features\n\n    backbone_result = net\n\n    all_logits = []\n    all_probabilities = []\n    all_calibrated_probabilities = []\n\n    for task_name in self._task_names:\n      safe_name = sanitize(task_name)\n      tower_outputs = self._towers[safe_name](net)\n      logits = tower_outputs[\"output\"]\n      scaled_logits = self._affine_maps[safe_name](logits)\n      probabilities = torch.sigmoid(scaled_logits)\n      calibrated_probabilities = self._calibrators[safe_name](probabilities)\n\n      all_logits.append(scaled_logits)\n      all_probabilities.append(probabilities)\n      all_calibrated_probabilities.append(calibrated_probabilities)\n\n    results = {\n      \"logits\": torch.squeeze(torch.stack(all_logits, dim=1), dim=-1),\n      \"probabilities\": torch.squeeze(torch.stack(all_probabilities, dim=1), dim=-1),\n      \"calibrated_probabilities\": torch.squeeze(\n        torch.stack(all_calibrated_probabilities, dim=1), dim=-1\n      ),\n    }\n\n    # Returning the backbone is intended for stitching post-tf conversion\n    # Leaving this on will ~200x the size of the output\n    # and could slow things down\n    if self.return_backbone:\n      results[\"backbone\"] = backbone_result\n\n    return results\n\n\ndef create_ranking_model(\n  data_spec,\n  # Used for planner to be batch size aware.\n  config: config_mod.RecapConfig,\n  device: torch.device,\n  loss_fn: Optional[Callable] = None,\n  data_config=None,\n  return_backbone=False,\n):\n\n  if list(config.model.tasks.values())[0].dlrm_config:\n    raise NotImplementedError()\n    model = EmbeddingRankingModel(\n      input_shapes=data_spec,\n      config=all_config.model,\n      data_config=all_config.train_data,\n    )\n  else:\n    model = MultiTaskRankingModel(\n      input_shapes=data_spec,\n      config=config.model,\n      data_config=data_config if data_config is not None else 
config.train_data,\n      return_backbone=return_backbone,\n    )\n\n  logging.info(\"***** Model Architecture *****\")\n  logging.info(model)\n\n  logging.info(\"***** Named Parameters *****\")\n  for elem in model.named_parameters():\n    logging.info(elem[0])\n\n  if loss_fn:\n    logging.info(\"***** Wrapping in loss *****\")\n    model = ModelAndLoss(\n      model=model,\n      loss_fn=loss_fn,\n      stratifiers=config.model.stratifiers,\n    )\n\n  return model\n"
  },
  {
    "path": "projects/home/recap/model/feature_transform.py",
    "content": "from typing import Mapping, Sequence, Union\n\nfrom tml.projects.home.recap.model.config import (\n  BatchNormConfig,\n  DoubleNormLogConfig,\n  FeaturizationConfig,\n  LayerNormConfig,\n)\n\nimport torch\n\n\ndef log_transform(x: torch.Tensor) -> torch.Tensor:\n  \"\"\"Safe log transform that works across both negative, zero, and positive floats.\"\"\"\n  return torch.sign(x) * torch.log1p(torch.abs(x))\n\n\nclass BatchNorm(torch.nn.Module):\n  def __init__(self, num_features: int, config: BatchNormConfig):\n    super().__init__()\n    self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return self.layer(x)\n\n\nclass LayerNorm(torch.nn.Module):\n  def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):\n    super().__init__()\n    if config.axis != -1:\n      raise NotImplementedError\n    if config.center != config.scale:\n      raise ValueError(\n        f\"Center and scale must match in torch, received {config.center}, {config.scale}\"\n      )\n    self.layer = torch.nn.LayerNorm(\n      normalized_shape, eps=config.epsilon, elementwise_affine=config.center\n    )\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return self.layer(x)\n\n\nclass Log1pAbs(torch.nn.Module):\n  def __init__(self):\n    super().__init__()\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return log_transform(x)\n\n\nclass InputNonFinite(torch.nn.Module):\n  def __init__(self, fill_value: float = 0):\n    super().__init__()\n\n    self.register_buffer(\n      \"fill_value\", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False\n    )\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return torch.where(torch.isfinite(x), x, self.fill_value)\n\n\nclass Clamp(torch.nn.Module):\n  def __init__(self, min_value: float, max_value: float):\n    super().__init__()\n    # Using buffer to 
make sure they are on correct device (and not moved every time).\n    # Will also be part of state_dict.\n    self.register_buffer(\n      \"min_value\", torch.as_tensor(min_value, dtype=torch.float32), persistent=True\n    )\n    self.register_buffer(\n      \"max_value\", torch.as_tensor(max_value, dtype=torch.float32), persistent=True\n    )\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    return torch.clamp(x, min=self.min_value, max=self.max_value)\n\n\nclass DoubleNormLog(torch.nn.Module):\n  \"\"\"Preprocesses dense features: continuous features have non-finite values replaced by 0, are sign-preserving log1p(abs)-transformed, optionally batch-normed, and clamped to +/- clip_magnitude; they are then concatenated with binary features and optionally layer-normed.\"\"\"\n\n  def __init__(\n    self,\n    input_shapes: Mapping[str, Sequence[int]],\n    config: DoubleNormLogConfig,\n  ):\n    super().__init__()\n\n    _before_concat_layers = [\n      InputNonFinite(),\n      Log1pAbs(),\n    ]\n    if config.batch_norm_config:\n      _before_concat_layers.append(\n        BatchNorm(input_shapes[\"continuous\"][-1], config.batch_norm_config)\n      )\n    _before_concat_layers.append(\n      Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude)\n    )\n    self._before_concat_layers = torch.nn.Sequential(*_before_concat_layers)\n\n    self.layer_norm = None\n    if config.layer_norm_config:\n      last_dim = input_shapes[\"continuous\"][-1] + input_shapes[\"binary\"][-1]\n      self.layer_norm = LayerNorm(last_dim, config.layer_norm_config)\n\n  def forward(\n    self, continuous_features: torch.Tensor, binary_features: torch.Tensor\n  ) -> torch.Tensor:\n    x = self._before_concat_layers(continuous_features)\n    x = torch.cat([x, binary_features], dim=1)\n    if self.layer_norm:\n      return self.layer_norm(x)\n    return x\n\n\ndef build_features_preprocessor(\n  config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]\n):\n  \"\"\"Trivial right now, but we will change in the future.\"\"\"\n  return DoubleNormLog(input_shapes, config.double_norm_log_config)\n"
  },
  {
    "path": "projects/home/recap/model/mask_net.py",
    "content": "\"\"\"MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619).\"\"\"\n\nfrom tml.projects.home.recap.model import config, mlp\n\nimport torch\n\n\ndef _init_weights(module):\n  if isinstance(module, torch.nn.Linear):\n    torch.nn.init.xavier_uniform_(module.weight)\n    torch.nn.init.constant_(module.bias, 0)\n\n\nclass MaskBlock(torch.nn.Module):\n  def __init__(\n    self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int\n  ) -> None:\n    super(MaskBlock, self).__init__()\n    self.mask_block_config = mask_block_config\n    output_size = mask_block_config.output_size\n\n    if mask_block_config.input_layer_norm:\n      self._input_layer_norm = torch.nn.LayerNorm(input_dim)\n    else:\n      self._input_layer_norm = None\n\n    if mask_block_config.reduction_factor:\n      aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)\n    elif mask_block_config.aggregation_size is not None:\n      aggregation_size = mask_block_config.aggregation_size\n    else:\n      raise ValueError(\"Need one of reduction factor or aggregation size.\")\n\n    self._mask_layer = torch.nn.Sequential(\n      torch.nn.Linear(mask_input_dim, aggregation_size),\n      torch.nn.ReLU(),\n      torch.nn.Linear(aggregation_size, input_dim),\n    )\n    self._mask_layer.apply(_init_weights)\n    self._hidden_layer = torch.nn.Linear(input_dim, output_size)\n    self._hidden_layer.apply(_init_weights)\n    self._layer_norm = torch.nn.LayerNorm(output_size)\n\n  def forward(self, net: torch.Tensor, mask_input: torch.Tensor):\n    if self._input_layer_norm:\n      net = self._input_layer_norm(net)\n    hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))\n    return self._layer_norm(hidden_layer_output)\n\n\nclass MaskNet(torch.nn.Module):\n  def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):\n    super().__init__()\n    self.mask_net_config = mask_net_config\n    mask_blocks = 
[]\n\n    if mask_net_config.use_parallel:\n      total_output_mask_blocks = 0\n      for mask_block_config in mask_net_config.mask_blocks:\n        mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))\n        total_output_mask_blocks += mask_block_config.output_size\n      self._mask_blocks = torch.nn.ModuleList(mask_blocks)\n    else:\n      input_size = in_features\n      for mask_block_config in mask_net_config.mask_blocks:\n        mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))\n        input_size = mask_block_config.output_size\n\n      self._mask_blocks = torch.nn.ModuleList(mask_blocks)\n      total_output_mask_blocks = mask_block_config.output_size\n\n    if mask_net_config.mlp:\n      self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)\n      self.out_features = mask_net_config.mlp.layer_sizes[-1]\n    else:\n      self.out_features = total_output_mask_blocks\n    self.shared_size = total_output_mask_blocks\n\n  def forward(self, inputs: torch.Tensor):\n    if self.mask_net_config.use_parallel:\n      mask_outputs = []\n      for mask_layer in self._mask_blocks:\n        mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))\n      # Share the outputs of the MaskBlocks.\n      all_mask_outputs = torch.cat(mask_outputs, dim=1)\n      output = (\n        all_mask_outputs\n        if self.mask_net_config.mlp is None\n        else self._dense_layers(all_mask_outputs)[\"output\"]\n      )\n      return {\"output\": output, \"shared_layer\": all_mask_outputs}\n    else:\n      net = inputs\n      for mask_layer in self._mask_blocks:\n        net = mask_layer(net=net, mask_input=inputs)\n      # Share the output of the stacked MaskBlocks.\n      output = net if self.mask_net_config.mlp is None else self._dense_layers[net][\"output\"]\n      return {\"output\": output, \"shared_layer\": net}\n"
  },
  {
    "path": "projects/home/recap/model/mlp.py",
    "content": "\"\"\"MLP feed forward stack in torch.\"\"\"\n\nfrom tml.projects.home.recap.model.config import MlpConfig\n\nimport torch\nfrom absl import logging\n\n\ndef _init_weights(module):\n  if isinstance(module, torch.nn.Linear):\n    torch.nn.init.xavier_uniform_(module.weight)\n    torch.nn.init.constant_(module.bias, 0)\n\n\nclass Mlp(torch.nn.Module):\n  def __init__(self, in_features: int, mlp_config: MlpConfig):\n    super().__init__()\n    self._mlp_config = mlp_config\n    input_size = in_features\n    layer_sizes = mlp_config.layer_sizes\n    modules = []\n    for layer_size in layer_sizes[:-1]:\n      modules.append(torch.nn.Linear(input_size, layer_size, bias=True))\n\n      if mlp_config.batch_norm:\n        modules.append(\n          torch.nn.BatchNorm1d(\n            layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum\n          )\n        )\n\n      modules.append(torch.nn.ReLU())\n\n      if mlp_config.dropout:\n        modules.append(torch.nn.Dropout(mlp_config.dropout.rate))\n\n      input_size = layer_size\n    modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))\n    if mlp_config.final_layer_activation:\n      modules.append(torch.nn.ReLU())\n    self.layers = torch.nn.ModuleList(modules)\n    self.layers.apply(_init_weights)\n\n  def forward(self, x: torch.Tensor) -> torch.Tensor:\n    net = x\n    for i, layer in enumerate(self.layers):\n      net = layer(net)\n      if i == 1:  # Share the first (widest?) set of activations for other applications.\n        shared_layer = net\n    return {\"output\": net, \"shared_layer\": shared_layer}\n\n  @property\n  def shared_size(self):\n    return self._mlp_config.layer_sizes[-1]\n\n  @property\n  def out_features(self):\n    return self._mlp_config.layer_sizes[-1]\n"
  },
  {
    "path": "projects/home/recap/model/model_and_loss.py",
    "content": "from typing import Callable, Optional, List\nfrom tml.projects.home.recap.embedding import config as embedding_config_mod\nimport torch\nfrom absl import logging\n\n\nclass ModelAndLoss(torch.nn.Module):\n  def __init__(\n    self,\n    model,\n    loss_fn: Callable,\n    stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,\n  ) -> None:\n    \"\"\"\n    Args:\n      model: torch module to wrap.\n      loss_fn: Function for calculating loss; called with logits, labels and weights.\n      stratifiers: optional list of stratifier configs (name and index into discrete features) to emit for metrics stratification.\n    \"\"\"\n    super().__init__()\n    self.model = model\n    self.loss_fn = loss_fn\n    self.stratifiers = stratifiers\n\n  def forward(self, batch: \"RecapBatch\"):  # type: ignore[name-defined]\n    \"\"\"Runs model forward and calculates loss according to given loss_fn.\n\n    NOTE: The input signature here needs to be a Pipelineable object for\n    prefetching purposes during training using torchrec's pipeline.  However\n    the underlying model signature needs to be exportable to onnx, requiring\n    generic python types.  
see https://pytorch.org/docs/stable/onnx.html#types.\n\n    \"\"\"\n    outputs = self.model(\n      continuous_features=batch.continuous_features,\n      binary_features=batch.binary_features,\n      discrete_features=batch.discrete_features,\n      sparse_features=batch.sparse_features,\n      user_embedding=batch.user_embedding,\n      user_eng_embedding=batch.user_eng_embedding,\n      author_embedding=batch.author_embedding,\n      labels=batch.labels,\n      weights=batch.weights,\n    )\n    losses = self.loss_fn(outputs[\"logits\"], batch.labels.float(), batch.weights.float())\n\n    if self.stratifiers:\n      logging.info(f\"***** Adding stratifiers *****\\n {self.stratifiers}\")\n      outputs[\"stratifiers\"] = {}\n      for stratifier in self.stratifiers:\n        outputs[\"stratifiers\"][stratifier.name] = batch.discrete_features[:, stratifier.index]\n\n    # In general, we can have a large number of losses returned by our loss function.\n    if isinstance(losses, dict):\n      return losses[\"loss\"], {\n        **outputs,\n        **losses,\n        \"labels\": batch.labels,\n        \"weights\": batch.weights,\n      }\n    else:  # Assume that this is a float.\n      return losses, {\n        **outputs,\n        \"loss\": losses,\n        \"labels\": batch.labels,\n        \"weights\": batch.weights,\n      }\n"
  },
  {
    "path": "projects/home/recap/model/numeric_calibration.py",
    "content": "import torch\n\n\nclass NumericCalibration(torch.nn.Module):\n  def __init__(\n    self,\n    pos_downsampling_rate: float,\n    neg_downsampling_rate: float,\n  ):\n    super().__init__()\n\n    # Using buffer to make sure they are on correct device (and not moved every time).\n    # Will also be part of state_dict.\n    self.register_buffer(\n      \"ratio\", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True\n    )\n\n  def forward(self, probs: torch.Tensor):\n    return probs * self.ratio / (1.0 - probs + (self.ratio * probs))\n"
  },
  {
    "path": "projects/home/recap/optimizer/__init__.py",
    "content": "from tml.projects.home.recap.optimizer.optimizer import build_optimizer\n"
  },
  {
    "path": "projects/home/recap/optimizer/config.py",
    "content": "\"\"\"Optimization configurations for models.\"\"\"\n\nimport typing\n\nimport tml.core.config as base_config\nimport tml.optimizers.config as optimizers_config_mod\n\nimport pydantic\n\n\nclass RecapAdamConfig(base_config.BaseConfig):\n  beta_1: float = 0.9  # Momentum term.\n  beta_2: float = 0.999  # Exponential weighted decay factor.\n  epsilon: float = 1e-7  # Numerical stability in denominator.\n\n\nclass MultiTaskLearningRates(base_config.BaseConfig):\n  tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(\n    description=\"Learning rates for different towers of the model.\"\n  )\n\n  backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(\n    None, description=\"Learning rate for backbone of the model.\"\n  )\n\n\nclass RecapOptimizerConfig(base_config.BaseConfig):\n  multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(\n    None, description=\"Multiple learning rates for different tasks.\", one_of=\"lr\"\n  )\n\n  single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(\n    None, description=\"Single task learning rates\", one_of=\"lr\"\n  )\n\n  adam: RecapAdamConfig = pydantic.Field(one_of=\"optimizer\")\n"
  },
  {
    "path": "projects/home/recap/optimizer/optimizer.py",
    "content": "\"\"\"Build optimizers and learning rate schedules.\"\"\"\nimport bisect\nfrom collections import defaultdict\nimport functools\nimport math\nimport typing\nfrom typing import Optional\nimport warnings\n\n# from large_embeddings.config import EmbeddingOptimizerConfig\nfrom tml.projects.home.recap import model as model_mod\nfrom tml.optimizers import config\nfrom tml.optimizers import compute_lr\nfrom absl import logging  # type: ignore[attr-defined]\n\nimport torch\nfrom torchrec.optim import keyed\n\n\n_DEFAULT_LR = 24601.0  # NaN the model if we're not using the learning rate.\n_BACKBONE = \"backbone\"\n_DENSE_EMBEDDINGS = \"dense_ebc\"\n\n\nclass RecapLRShim(torch.optim.lr_scheduler._LRScheduler):\n  \"\"\"Shim to get learning rates into a LRScheduler.\n\n  This adheres to the torch.optim scheduler API and can be plugged anywhere that\n  e.g. exponential decay can be used.\n\n  \"\"\"\n\n  def __init__(\n    self,\n    optimizer,\n    lr_dict: typing.Dict[str, config.LearningRate],\n    emb_learning_rate,\n    last_epoch=-1,\n    verbose=False,\n  ):\n    self.optimizer = optimizer\n    self.lr_dict = lr_dict\n    self.group_names = list(self.lr_dict.keys())\n    self.emb_learning_rate = emb_learning_rate\n\n    # We handle sparse LR scheduling separately, so only validate LR groups against dense param groups\n    num_dense_param_groups = sum(\n      1\n      for _, _optim in optimizer._optims\n      for _ in _optim.param_groups\n      if isinstance(_optim, keyed.KeyedOptimizerWrapper)\n    )\n    if num_dense_param_groups != len(lr_dict):\n      raise ValueError(\n        f\"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}.\"\n      )\n    super().__init__(optimizer, last_epoch, verbose)\n\n  def get_lr(self):\n    if not self._get_lr_called_within_step:\n      warnings.warn(\n        \"To get the last learning rate computed by the scheduler, \" \"please use `get_last_lr()`.\",\n        UserWarning,\n      )\n    return 
self._get_closed_form_lr()\n\n  def _get_closed_form_lr(self):\n    learning_rates = []\n\n    for lr_config in self.lr_dict.values():\n      learning_rates.append(compute_lr(lr_config, self.last_epoch))\n    # WARNING: The order of appending is important.\n    if self.emb_learning_rate:\n      learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))\n    return learning_rates\n\n\ndef build_optimizer(\n  model: torch.nn.Module,\n  optimizer_config: config.OptimizerConfig,\n  emb_optimizer_config: None = None,  # Optional[EmbeddingOptimizerConfig] = None,\n):\n  \"\"\"Builds an optimizer and scheduler.\n\n  Args:\n    model: A torch model, probably with DDP/DMP.\n    optimizer_config: An OptimizerConfig object that specifies learning rates per tower.\n\n  Returns:\n    A torch.optim instance, and a scheduler instance.\n  \"\"\"\n  optimizer_fn = functools.partial(\n    torch.optim.Adam,\n    lr=_DEFAULT_LR,\n    betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),\n    eps=optimizer_config.adam.epsilon,\n    maximize=False,\n  )\n  if optimizer_config.multi_task_learning_rates:\n    logging.info(\"***** Parameter groups for optimization *****\")\n    # Importantly, we preserve insertion order in dictionaries here.\n    parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)\n    added_parameters: typing.Set[str] = set()\n    for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:\n      for name, parameter in model.named_parameters():\n        if f\".{model_mod.sanitize(task)}.\" in name:\n          parameter_groups[task][name] = parameter\n          logging.info(f\"{task}: {name}\")\n          if name in added_parameters:\n            raise ValueError(f\"Parameter {name} matched multiple tasks.\")\n          added_parameters.add(name)\n\n    for name, parameter in model.named_parameters():\n      if name not in added_parameters and \"embedding_bags\" not in name:\n        
parameter_groups[_BACKBONE][name] = parameter\n        added_parameters.add(name)\n        logging.info(f\"{_BACKBONE}: {name}\")\n\n    for name, parameter in model.named_parameters():\n      if name not in added_parameters and \"embedding_bags\" in name:\n        parameter_groups[_DENSE_EMBEDDINGS][name] = parameter\n        logging.info(f\"{_DENSE_EMBEDDINGS}: {name}\")\n\n    all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()\n    if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:\n      all_learning_rates[\n        _BACKBONE\n      ] = optimizer_config.multi_task_learning_rates.backbone_learning_rate\n    if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:\n      all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()\n  else:\n    parameter_groups = dict(model.named_parameters())\n    all_learning_rates = {\"single_task\": optimizer_config.single_task_learning_rate}\n\n  optimizers = [\n    keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)\n    for param_name, param_group in parameter_groups.items()\n    if param_name != _DENSE_EMBEDDINGS\n  ]\n  # Making EBC optimizer to be SGD to match fused optimiser\n  if _DENSE_EMBEDDINGS in parameter_groups:\n    optimizers.append(\n      keyed.KeyedOptimizerWrapper(\n        parameter_groups[_DENSE_EMBEDDINGS],\n        functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False),\n      )\n    )\n\n  if not parameter_groups.keys() == all_learning_rates.keys():\n    raise ValueError(\"Learning rates do not match optimizers\")\n\n  # If the optimiser is dense, model.fused_optimizer will be empty (but not None)\n  emb_learning_rate = None\n  if hasattr(model, \"fused_optimizer\") and model.fused_optimizer.optimizers:\n    logging.info(f\"Model fused optimiser: {model.fused_optimizer}\")\n    optimizers.append(model.fused_optimizer)\n    if emb_optimizer_config:\n      
emb_learning_rate = emb_optimizer_config.learning_rate.copy()\n    else:\n      raise ValueError(\"Fused kernel exists, but LR is not set\")\n  logging.info(f\"***** Combining optimizers: {optimizers} *****\")\n  optimizer = keyed.CombinedOptimizer(optimizers)\n  scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)\n  logging.info(f\"***** Combined optimizer after init: {optimizer} *****\")\n\n  return optimizer, scheduler\n"
  },
  {
    "path": "projects/home/recap/script/create_random_data.sh",
    "content": "#!/usr/bin/env bash\n\n# Runs from inside venv\n\nrm -rf $HOME/tmp/runs/recap_local_random_data\npython -m tml.machines.is_venv || exit 1\nexport TML_BASE=\"$(git rev-parse --show-toplevel)\"\n\nmkdir -p $HOME/tmp/recap_local_random_data\npython projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml\n"
  },
  {
    "path": "projects/home/recap/script/run_local.sh",
    "content": "#!/usr/bin/env bash\n\n# Runs from inside venv\nrm -rf $HOME/tmp/runs/recap_local_debug\nmkdir -p $HOME/tmp/runs/recap_local_debug\npython -m tml.machines.is_venv || exit 1\nexport TML_BASE=\"$(git rev-parse --show-toplevel)\"\n\ntorchrun \\\n  --standalone \\\n  --nnodes 1 \\\n  --nproc_per_node 1 \\\n  projects/home/recap/main.py \\\n  --config_path $(pwd)/projects/home/recap/config/local_prod.yaml \\\n  $@\n"
  },
  {
    "path": "projects/twhin/README.md",
    "content": "TwHIN in TorchRec\n\nThis project contains code for pretraining dense vector embedding features for Twitter entities. Within Twitter, these embeddings are used for candidate retrieval and as model features in a variety of recommender system models.\n\nWe obtain entity embeddings based on a variety of graph data within Twitter, such as:\n  \"User follows User\"\n  \"User favorites Tweet\"\n  \"User clicks Advertisement\"\n\nWhile we cannot release the graph data used to train TwHIN embeddings due to privacy restrictions, heavily subsampled, anonymized, open-sourced graph data can be used:\nhttps://huggingface.co/datasets/Twitter/TwitterFollowGraph\nhttps://huggingface.co/datasets/Twitter/TwitterFaveGraph\n\nThe code expects parquet files with three columns, `lhs`, `rel`, and `rhs`, which refer to the vocab index of the left-hand-side node, the relation type, and the vocab index of the right-hand-side node of each edge in the graph, respectively.\n\nThe location of the data must be specified in the configuration YAML files in projects/twhin/config.\n\n\nWorkflow\n========\n- Build local development images with `./scripts/build_images.sh`\n- Run with `./scripts/docker_run.sh`\n- Iterate in image with `./scripts/idocker.sh`\n- Run tests with `./scripts/docker_test.sh`\n"
  },
  {
    "path": "projects/twhin/config/local.yaml",
    "content": "runtime:\n  enable_amp: false\ntraining:\n  save_dir: \"/tmp/model\"\n  num_train_steps: 100000\n  checkpoint_every_n: 100000\n  train_log_every_n: 10\n  num_eval_steps: 1000\n  eval_log_every_n: 500\n  eval_timeout_in_s: 10000\n  num_epochs: 5\nmodel:\n  translation_optimizer:\n    sgd:\n      lr: 0.05\n    learning_rate:\n      constant: 0.05\n  embeddings:\n    tables:\n      - name: user\n        num_embeddings: 424_241\n        embedding_dim: 4\n        data_type: fp32\n        optimizer:\n          sgd:\n            lr: 0.01\n          learning_rate:\n            constant: 0.01\n      - name: tweet\n        num_embeddings: 72_543\n        embedding_dim: 4\n        data_type: fp32\n        optimizer:\n          sgd:\n            lr: 0.005\n          learning_rate:\n            constant: 0.005\n  relations:\n    - name: fav\n      lhs: user\n      rhs: tweet\n      operator: translation\n    - name: reply\n      lhs: user\n      rhs: tweet\n      operator: translation\n    - name: retweet\n      lhs: user\n      rhs: tweet\n      operator: translation\n    - name: magic_recs\n      lhs: user\n      rhs: tweet\n      operator: translation\ntrain_data:\n  data_root: \"gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*\"\n  per_replica_batch_size: 500\n  global_negatives: 0\n  in_batch_negatives: 10\n  limit: 9990\nvalidation_data:\n  data_root: \"gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*\"\n  per_replica_batch_size: 500\n  global_negatives: 0\n  in_batch_negatives: 10\n  limit: 10\n  offset: 9990\n"
  },
  {
    "path": "projects/twhin/config.py",
    "content": "from tml.core.config import base_config\nfrom tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.projects.twhin.models.config import TwhinModelConfig\nfrom tml.core.config.training import RuntimeConfig, TrainingConfig\n\nimport pydantic\n\n\nclass TwhinConfig(base_config.BaseConfig):\n  runtime: RuntimeConfig = pydantic.Field(RuntimeConfig())\n  training: TrainingConfig = pydantic.Field(TrainingConfig())\n  model: TwhinModelConfig\n  train_data: TwhinDataConfig\n  validation_data: TwhinDataConfig\n"
  },
  {
    "path": "projects/twhin/data/config.py",
    "content": "from tml.core.config import base_config\n\nimport pydantic\n\n\nclass TwhinDataConfig(base_config.BaseConfig):\n  data_root: str\n  per_replica_batch_size: pydantic.PositiveInt\n  global_negatives: int\n  in_batch_negatives: int\n  limit: pydantic.PositiveInt\n  offset: pydantic.PositiveInt = pydantic.Field(\n    None, description=\"The offset to start reading from.\"\n  )\n"
  },
  {
    "path": "projects/twhin/data/data.py",
    "content": "from tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.projects.twhin.models.config import TwhinModelConfig\nfrom tml.projects.twhin.data.edges import EdgesDataset\n\n\ndef create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):\n  tables = model_config.embeddings.tables\n  table_sizes = {table.name: table.num_embeddings for table in tables}\n  relations = model_config.relations\n\n  pos_batch_size = data_config.per_replica_batch_size\n\n  return EdgesDataset(\n    file_pattern=data_config.data_root,\n    relations=relations,\n    table_sizes=table_sizes,\n    batch_size=pos_batch_size,\n  )\n"
  },
  {
    "path": "projects/twhin/data/edges.py",
    "content": "from collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.reader.dataset import Dataset\nfrom tml.projects.twhin.models.config import Relation\n\nimport numpy as np\nimport pyarrow as pa\nimport pyarrow.compute as pc\nimport torch\nfrom torchrec.sparse.jagged_tensor import KeyedJaggedTensor\n\n\n@dataclass\nclass EdgeBatch(DataclassBatch):\n  nodes: KeyedJaggedTensor\n  labels: torch.Tensor\n  rels: torch.Tensor\n  weights: torch.Tensor\n\n\nclass EdgesDataset(Dataset):\n  rng = np.random.default_rng()\n\n  def __init__(\n    self,\n    file_pattern: str,\n    table_sizes: Dict[str, int],\n    relations: List[Relation],\n    lhs_column_name: str = \"lhs\",\n    rhs_column_name: str = \"rhs\",\n    rel_column_name: str = \"rel\",\n    **dataset_kwargs\n  ):\n    self.batch_size = dataset_kwargs[\"batch_size\"]\n\n    self.table_sizes = table_sizes\n    self.num_tables = len(table_sizes)\n    self.table_names = list(table_sizes.keys())\n\n    self.relations = relations\n    self.relations_t = torch.tensor(\n      [\n        [self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)]\n        for relation in relations\n      ]\n    )\n\n    self.lhs_column_name = lhs_column_name\n    self.rhs_column_name = rhs_column_name\n    self.rel_column_name = rel_column_name\n    self.label_column_name = \"label\"\n\n    super().__init__(file_pattern=file_pattern, **dataset_kwargs)\n\n  def pa_to_batch(self, batch: pa.RecordBatch):\n    lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())\n    rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())\n    rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())\n    label = torch.from_numpy(batch.column(self.label_column_name).to_numpy())\n\n    nodes = self._to_kjt(lhs, rhs, rel)\n    return EdgeBatch(\n      nodes=nodes,\n      rels=rel,\n      
labels=label,\n      weights=torch.ones(batch.num_rows),\n    )\n\n  def _to_kjt(\n    self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor\n  ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:\n\n    \"\"\"Process edges that contain lhs index, rhs index, relation index.\n    Example:\n\n    ```\n    tables = [\"f0\", \"f1\", \"f2\", \"f3\"]\n    relations = [[\"f0\", \"f1\"], [\"f1\", \"f2\"], [\"f1\", \"f0\"], [\"f2\", \"f1\"], [\"f0\", \"f2\"]]\n    self.relations_t = torch.Tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]])\n    lhs = [1, 6, 3, 1, 8]\n    rhs = [6, 3, 4, 4, 9]\n    rel = [0, 2, 1, 3, 4]\n\n    This corresponds to the following \"edges\":\n    edges = [\n      {\"lhs\": 1, \"rhs\": 6, \"relation\": [\"f0\", \"f1\"]},\n      {\"lhs\": 6, \"rhs\": 3, \"relation\": [\"f1\", \"f0\"]},\n      {\"lhs\": 3, \"rhs\": 4, \"relation\": [\"f1\", \"f2\"]},\n      {\"lhs\": 1, \"rhs\": 4, \"relation\": [\"f2\", \"f1\"]},\n      {\"lhs\": 8, \"rhs\": 9, \"relation\": [\"f0\", \"f2\"]},\n    ]\n    ```\n\n    Returns a KeyedJaggedTensor used to look up all embeddings.\n\n    Note: We treat the lhs and rhs as though they're separate lookups: `len(lenghts) == 2 * bsz * len(tables)`.\n    This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`.\n\n    For the example above:\n    ```\n    lookups = tensor([\n      [0., 1.],\n      [1., 6.],\n      [1., 6.],\n      [0., 3.],\n      [1., 3.],\n      [2., 4.],\n      [2., 1.],\n      [1., 4.],\n      [0., 8.],\n      [2., 9.]\n    ])\n\n    kjt = KeyedJaggedTensor(\n      features=[\"f0\", \"f1\", \"f2\"]\n      values=[\n        1, 3, 8,      # f0\n        6, 6, 3, 4,   # f1\n        4, 1, 9       # f2\n      ]\n      lengths=[\n        1, 0, 0, 1, 0, 0, 0, 0, 1, 0,  # f0\n        0, 1, 1, 0, 1, 0, 0, 1, 0, 0,  # f1\n        0, 0, 0, 0, 0, 1, 1, 0, 0, 1,  # f2\n    )\n    ```\n\n    Note:\n      - values = [values for f0] + [values for f1] + [values for f2]\n      - lengths are 
always 0 or 1, and sum(lengths) = len(values) = 2 * bsz\n    \"\"\"\n    lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1)\n    index = torch.LongTensor([1, 0, 2, 3])\n    lookups = lookups[:, index].reshape(2 * self.batch_size, 2)\n\n    # values is just the row indices into each table, ordered by the table indices\n    _, indices = torch.sort(lookups[:, 0], dim=0, stable=True)\n    values = lookups[indices][:, 1].int()\n\n    # lengths[table_idx * batch_size + i] == whether the ith lookup is for the table with index table_idx\n    lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0])\n    lengths = lengths.reshape(-1).int()\n\n    return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)\n\n  def to_batches(self):\n    ds = super().to_batches()\n    batch_size = self._dataset_kwargs[\"batch_size\"]\n\n    names = [\n      self.lhs_column_name,\n      self.rhs_column_name,\n      self.rel_column_name,\n      self.label_column_name,\n    ]\n    for _, batch in enumerate(ds):\n      # Pass along positive edges\n      lhs = batch.column(self.lhs_column_name)\n      rhs = batch.column(self.rhs_column_name)\n      rel = batch.column(self.rel_column_name)\n      label = pa.array(np.ones(batch_size, dtype=np.int64))\n\n      yield pa.RecordBatch.from_arrays(\n        arrays=[lhs, rhs, rel, label],\n        names=names,\n      )\n"
  },
  {
    "path": "projects/twhin/data/test_data.py",
    "content": "import pytest\nfrom unittest.mock import Mock\n\n\ndef test_create_dataset():\n  pass\n"
  },
  {
    "path": "projects/twhin/data/test_edges.py",
    "content": "\"\"\"Tests edges dataset functionality.\"\"\"\n\nfrom unittest.mock import patch\nimport os\nimport tempfile\n\nfrom tml.projects.twhin.data.edges import EdgesDataset\nfrom tml.projects.twhin.models.config import Relation\n\nfrom fsspec.implementations.local import LocalFileSystem\nimport numpy as np\nimport pyarrow as pa\nimport pyarrow.compute as pc\nimport pyarrow.parquet as pq\nimport torch\n\n\nTABLE_SIZES = {\"user\": 16, \"author\": 32}\nRELATIONS = [\n  Relation(name=\"fav\", lhs=\"user\", rhs=\"author\"),\n  Relation(name=\"engaged_with_reply\", lhs=\"author\", rhs=\"user\"),\n]\n\n\ndef test_gen():\n  import os\n  import tempfile\n\n  from fsspec.implementations.local import LocalFileSystem\n  import pyarrow as pa\n  import pyarrow.parquet as pq\n\n  lhs = pa.array(np.arange(4))\n  rhs = pa.array(np.flip(np.arange(4)))\n  rel = pa.array([0, 1, 0, 0])\n  names = [\"lhs\", \"rhs\", \"rel\"]\n\n  with tempfile.TemporaryDirectory() as tmpdir:\n    table = pa.Table.from_arrays([lhs, rhs, rel], names=names)\n    writer = pq.ParquetWriter(\n      os.path.join(tmpdir, \"example.parquet\"),\n      table.schema,\n    )\n    writer.write_table(table)\n    writer.close()\n\n    ds = EdgesDataset(\n      file_pattern=os.path.join(tmpdir, \"*\"),\n      table_sizes=TABLE_SIZES,\n      relations=RELATIONS,\n      batch_size=4,\n    )\n    ds.FS = LocalFileSystem()\n\n    dl = ds.dataloader()\n    batch = next(iter(dl))\n\n    # labels should be positive\n    labels = batch.labels\n    assert (labels[:4] == 1).sum() == 4\n\n    # make sure positive examples are what we expect\n    kjt_values = batch.nodes.values()\n    users, authors = torch.split(kjt_values, 4, dim=0)\n    assert torch.equal(users[:4], torch.tensor([0, 2, 2, 3]))\n    assert torch.equal(authors[:4], torch.tensor([3, 1, 1, 0]))\n"
  },
  {
    "path": "projects/twhin/machines.yaml",
    "content": "chief: &gpu\n  mem: 1.4Ti\n  cpu: 24\n  num_accelerators: 16\n  accelerator_type: a100\ndataset_dispatcher:\n  mem: 2Gi\n  cpu: 2\nnum_dataset_workers: 4\ndataset_worker:\n  mem: 14Gi\n  cpu: 2\n"
  },
  {
    "path": "projects/twhin/metrics.py",
    "content": "import torch\nimport torchmetrics as tm\n\nimport tml.core.metrics as core_metrics\n\n\ndef create_metrics(\n  device: torch.device,\n):\n  metrics = dict()\n  metrics.update(\n    {\n      \"AUC\": core_metrics.Auc(128),\n    }\n  )\n  metrics = tm.MetricCollection(metrics).to(device)\n  return metrics\n"
  },
  {
    "path": "projects/twhin/models/config.py",
    "content": "import typing\nimport enum\n\nfrom tml.common.modules.embedding.config import LargeEmbeddingsConfig\nfrom tml.core.config import base_config\nfrom tml.optimizers.config import OptimizerConfig\n\nimport pydantic\nfrom pydantic import validator\n\n\nclass TwhinEmbeddingsConfig(LargeEmbeddingsConfig):\n  @validator(\"tables\")\n  def embedding_dims_match(cls, tables):\n    embedding_dim = tables[0].embedding_dim\n    data_type = tables[0].data_type\n    for table in tables:\n      assert table.embedding_dim == embedding_dim, \"Embedding dimensions for all nodes must match.\"\n      assert table.data_type == data_type, \"Data types for all nodes must match.\"\n    return tables\n\n\nclass Operator(str, enum.Enum):\n  TRANSLATION = \"translation\"\n\n\nclass Relation(pydantic.BaseModel):\n  \"\"\"graph relationship properties and operator\"\"\"\n\n  name: str = pydantic.Field(..., description=\"Relationship name.\")\n  lhs: str = pydantic.Field(\n    ...,\n    description=\"Name of the entity on the left-hand-side of this relation. Must match a table name.\",\n  )\n  rhs: str = pydantic.Field(\n    ...,\n    description=\"Name of the entity on the right-hand-side of this relation. Must match a table name.\",\n  )\n  operator: Operator = pydantic.Field(\n    Operator.TRANSLATION, description=\"Transformation to apply to lhs embedding before dot product.\"\n  )\n\n\nclass TwhinModelConfig(base_config.BaseConfig):\n  embeddings: TwhinEmbeddingsConfig\n  relations: typing.List[Relation]\n  translation_optimizer: OptimizerConfig\n\n  @validator(\"relations\", each_item=True)\n  def valid_node_types(cls, relation, values, **kwargs):\n    table_names = [table.name for table in values[\"embeddings\"].tables]\n    assert relation.lhs in table_names, f\"Invalid lhs node type: {relation.lhs}\"\n    assert relation.rhs in table_names, f\"Invalid rhs node type: {relation.rhs}\"\n    return relation\n"
  },
  {
    "path": "projects/twhin/models/models.py",
    "content": "from typing import Callable\nimport math\n\nfrom tml.projects.twhin.data.edges import EdgeBatch\nfrom tml.projects.twhin.models.config import TwhinModelConfig\nfrom tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.common.modules.embedding.embedding import LargeEmbeddings\nfrom tml.optimizers.optimizer import get_optimizer_class\nfrom tml.optimizers.config import get_optimizer_algorithm_config\n\nimport torch\nfrom torch import nn\nfrom torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\n\n\nclass TwhinModel(nn.Module):\n  def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):\n    super().__init__()\n    self.batch_size = data_config.per_replica_batch_size\n    self.table_names = [table.name for table in model_config.embeddings.tables]\n    self.large_embeddings = LargeEmbeddings(model_config.embeddings)\n    self.embedding_dim = model_config.embeddings.tables[0].embedding_dim\n    self.num_tables = len(model_config.embeddings.tables)\n    self.in_batch_negatives = data_config.in_batch_negatives\n    self.global_negatives = data_config.global_negatives\n    self.num_relations = len(model_config.relations)\n\n    # one bias per relation\n    self.all_trans_embs = torch.nn.parameter.Parameter(\n      torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))\n    )\n\n  def forward(self, batch: EdgeBatch):\n\n    # B x D\n    trans_embs = self.all_trans_embs.data[batch.rels]\n\n    # KeyedTensor\n    outs = self.large_embeddings(batch.nodes)\n\n    # 2B x TD\n    x = outs.values()\n\n    # 2B x T x D\n    x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)\n\n    # 2B x D\n    x = torch.sum(x, 1)\n\n    # B x 2 x D\n    x = x.reshape(self.batch_size, 2, self.embedding_dim)\n\n    # translated\n    translated = x[:, 1, :] + trans_embs\n\n    negs = []\n    if self.in_batch_negatives:\n      # construct dot products for negatives via matmul\n      for relation 
in range(self.num_relations):\n        rel_mask = batch.rels == relation\n        rel_count = rel_mask.sum()\n\n        if not rel_count:\n          continue\n\n        # R x D\n        lhs_matrix = x[rel_mask, 0, :]\n        rhs_matrix = x[rel_mask, 1, :]\n\n        lhs_perm = torch.randperm(lhs_matrix.shape[0])\n        # repeat until we have enough negatives\n        lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))\n        lhs_indices = lhs_perm[: self.in_batch_negatives]\n        sampled_lhs = lhs_matrix[lhs_indices]\n\n        rhs_perm = torch.randperm(rhs_matrix.shape[0])\n        # repeat until we have enough negatives\n        rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))\n        rhs_indices = rhs_perm[: self.in_batch_negatives]\n        sampled_rhs = rhs_matrix[rhs_indices]\n\n        # RS\n        negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))\n        negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))\n\n        negs.append(negs_lhs)\n        negs.append(negs_rhs)\n\n    # dot product for positives\n    x = (x[:, 0, :] * translated).sum(-1)\n\n    # concat positives and negatives\n    x = torch.cat([x, *negs])\n    return {\n      \"logits\": x,\n      \"probabilities\": torch.sigmoid(x),\n    }\n\n\ndef apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):\n  for table in model_config.embeddings.tables:\n    optimizer_class = get_optimizer_class(table.optimizer)\n    optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()\n    params = [\n      param\n      for name, param in model.large_embeddings.ebc.named_parameters()\n      if (name.startswith(f\"embedding_bags.{table.name}\"))\n    ]\n    apply_optimizer_in_backward(\n      optimizer_class=optimizer_class,\n      params=params,\n      optimizer_kwargs=optimizer_kwargs,\n    )\n\n  return model\n\n\nclass TwhinModelAndLoss(torch.nn.Module):\n  def __init__(\n  
  self,\n    model,\n    loss_fn: Callable,\n    data_config: TwhinDataConfig,\n    device: torch.device,\n  ) -> None:\n    \"\"\"\n    Args:\n      model: torch module to wrap.\n      loss_fn: Function for calculating loss, should accept logits and labels.\n    \"\"\"\n    super().__init__()\n    self.model = model\n    self.loss_fn = loss_fn\n    self.batch_size = data_config.per_replica_batch_size\n    self.in_batch_negatives = data_config.in_batch_negatives\n    self.device = device\n\n  def forward(self, batch: \"RecapBatch\"):  # type: ignore[name-defined]\n    \"\"\"Runs model forward and calculates loss according to given loss_fn.\n\n    NOTE: The input signature here needs to be a Pipelineable object for\n    prefetching purposes during training using torchrec's pipeline.  However\n    the underlying model signature needs to be exportable to onnx, requiring\n    generic python types.  see https://pytorch.org/docs/stable/onnx.html#types.\n\n    \"\"\"\n    outputs = self.model(batch)\n    logits = outputs[\"logits\"]\n\n    num_negatives = 2 * self.batch_size * self.in_batch_negatives\n    num_positives = self.batch_size\n\n    neg_weight = float(num_positives) / num_negatives\n\n    labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])\n\n    weights = torch.cat(\n      [batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]\n    )\n\n    losses = self.loss_fn(logits, labels, weights)\n\n    outputs.update(\n      {\n        \"loss\": losses,\n        \"labels\": labels,\n        \"weights\": weights,\n      }\n    )\n\n    # Allow multiple losses.\n    return losses, outputs\n"
  },
  {
    "path": "projects/twhin/models/test_models.py",
    "content": "from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig\nfrom tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.common.modules.embedding.config import DataType, EmbeddingBagConfig\nfrom tml.optimizers.config import OptimizerConfig, SgdConfig\nfrom tml.model import maybe_shard_model\nfrom tml.projects.twhin.models.models import apply_optimizers, TwhinModel\nfrom tml.projects.twhin.models.config import Operator, Relation\nfrom tml.common.testing_utils import mock_pg\n\nimport torch\nimport torch.nn.functional as F\nfrom pydantic import ValidationError\nimport pytest\n\n\nNUM_EMBS = 10_000\nEMB_DIM = 128\n\n\ndef twhin_model_config() -> TwhinModelConfig:\n  sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))\n  sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))\n\n  table0 = EmbeddingBagConfig(\n    name=\"table0\",\n    num_embeddings=NUM_EMBS,\n    embedding_dim=EMB_DIM,\n    optimizer=sgd_config_0,\n    data_type=DataType.FP32,\n  )\n  table1 = EmbeddingBagConfig(\n    name=\"table1\",\n    num_embeddings=NUM_EMBS,\n    embedding_dim=EMB_DIM,\n    optimizer=sgd_config_1,\n    data_type=DataType.FP32,\n  )\n  embeddings_config = TwhinEmbeddingsConfig(\n    tables=[table0, table1],\n  )\n\n  model_config = TwhinModelConfig(\n    embeddings=embeddings_config,\n    translation_optimizer=sgd_config_0,\n    relations=[\n      Relation(name=\"rel0\", lhs=\"table0\", rhs=\"table1\", operator=Operator.TRANSLATION),\n      Relation(name=\"rel1\", lhs=\"table1\", rhs=\"table0\", operator=Operator.TRANSLATION),\n    ],\n  )\n\n  return model_config\n\n\ndef twhin_data_config() -> TwhinDataConfig:\n  data_config = TwhinDataConfig(\n    data_root=\"/\",\n    per_replica_batch_size=10,\n    global_negatives=10,\n    in_batch_negatives=10,\n    limit=1,\n    offset=1,\n  )\n\n  return data_config\n\n\ndef test_twhin_model():\n  model_config = twhin_model_config()\n  loss_fn = F.binary_cross_entropy_with_logits\n\n  
with mock_pg():\n    data_config = twhin_data_config()\n    model = TwhinModel(model_config=model_config, data_config=data_config)\n\n    apply_optimizers(model, model_config)\n\n    for tensor in model.state_dict().values():\n      if tensor.size() == (NUM_EMBS, EMB_DIM):\n        assert str(tensor.device) == \"meta\"\n      else:\n        assert str(tensor.device) == \"cpu\"\n\n    model = maybe_shard_model(model, device=torch.device(\"cpu\"))\n\n\ndef test_unequal_dims():\n  sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))\n  sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))\n  table0 = EmbeddingBagConfig(\n    name=\"table0\",\n    num_embeddings=10_000,\n    embedding_dim=128,\n    optimizer=sgd_config_1,\n    data_type=DataType.FP32,\n  )\n  table1 = EmbeddingBagConfig(\n    name=\"table1\",\n    num_embeddings=10_000,\n    embedding_dim=64,\n    optimizer=sgd_config_2,\n    data_type=DataType.FP32,\n  )\n\n  with pytest.raises(ValidationError):\n    _ = TwhinEmbeddingsConfig(\n      tables=[table0, table1],\n    )\n"
  },
  {
    "path": "projects/twhin/optimizer.py",
    "content": "import functools\n\nfrom tml.projects.twhin.models.config import TwhinModelConfig\nfrom tml.projects.twhin.models.models import TwhinModel\nfrom tml.optimizers.optimizer import get_optimizer_class, LRShim\nfrom tml.optimizers.config import get_optimizer_algorithm_config, LearningRate\nfrom tml.ml_logging.torch_logging import logging\n\nfrom torchrec.optim.optimizers import in_backward_optimizer_filter\nfrom torchrec.optim import keyed\n\n\nFUSED_OPT_KEY = \"fused_opt\"\nTRANSLATION_OPT_KEY = \"operator_opt\"\n\n\ndef _lr_from_config(optimizer_config):\n  if optimizer_config.learning_rate is not None:\n    return optimizer_config.learning_rate\n  else:\n    # treat None as constant lr\n    lr_value = get_optimizer_algorithm_config(optimizer_config).lr\n    return LearningRate(constant=lr_value)\n\n\ndef build_optimizer(model: TwhinModel, config: TwhinModelConfig):\n  \"\"\"Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.\n\n  Args:\n    model: TwhinModel to build optimizer for.\n    config: TwhinConfig for model.\n\n  Returns:\n    Optimizer for model.\n  \"\"\"\n  translation_optimizer_fn = functools.partial(\n    get_optimizer_class(config.translation_optimizer),\n    **get_optimizer_algorithm_config(config.translation_optimizer).dict(),\n  )\n\n  translation_optimizer = keyed.KeyedOptimizerWrapper(\n    dict(in_backward_optimizer_filter(model.named_parameters())),\n    optim_factory=translation_optimizer_fn,\n  )\n\n  lr_dict = {}\n  for table in config.embeddings.tables:\n    lr_dict[table.name] = _lr_from_config(table.optimizer)\n  lr_dict[TRANSLATION_OPT_KEY] = _lr_from_config(config.translation_optimizer)\n\n  logging.info(f\"***** LR dict: {lr_dict} *****\")\n\n  logging.info(\n    f\"***** Combining fused optimizer {model.fused_optimizer} with operator optimizer: {translation_optimizer} *****\"\n  )\n  optimizer = keyed.CombinedOptimizer(\n    [\n      (FUSED_OPT_KEY, 
model.fused_optimizer),\n      (TRANSLATION_OPT_KEY, translation_optimizer),\n    ]\n  )\n\n  # scheduler = LRShim(optimizer, lr_dict)\n  scheduler = None\n\n  logging.info(f\"***** Combined optimizer after init: {optimizer} *****\")\n\n  return optimizer, scheduler\n"
  },
  {
    "path": "projects/twhin/run.py",
    "content": "from absl import app, flags\nimport json\nfrom typing import Optional\nimport os\nimport sys\n\nimport torch\n\n# isort: on\nfrom tml.common.device import setup_and_get_device\nfrom tml.common.utils import setup_configuration\nimport tml.core.custom_training_loop as ctl\nimport tml.machines.environment as env\nfrom tml.projects.twhin.models.models import apply_optimizers, TwhinModel, TwhinModelAndLoss\nfrom tml.model import maybe_shard_model\nfrom tml.projects.twhin.metrics import create_metrics\nfrom tml.projects.twhin.config import TwhinConfig\nfrom tml.projects.twhin.data.data import create_dataset\nfrom tml.projects.twhin.optimizer import build_optimizer\n\nfrom tml.ml_logging.torch_logging import logging\n\nimport torch.distributed as dist\nfrom torch.nn import functional as F\nfrom torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\nfrom torchrec.distributed.model_parallel import get_module\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_bool(\"overwrite_save_dir\", False, \"Whether to clear preexisting save directories.\")\nflags.DEFINE_string(\"save_dir\", None, \"If provided, overwrites the save directory.\")\nflags.DEFINE_string(\"config_yaml_path\", None, \"Path to hyperparameters for model.\")\nflags.DEFINE_string(\"task\", None, \"Task to run if this is local. 
Overrides TF_CONFIG etc.\")\n\n\ndef run(\n  all_config: TwhinConfig,\n  save_dir: Optional[str] = None,\n):\n  train_dataset = create_dataset(all_config.train_data, all_config.model)\n\n  if env.is_reader():\n    train_dataset.serve()\n  if env.is_chief():\n    device = setup_and_get_device(tf_ok=False)\n    logging.info(f\"device: {device}\")\n    logging.info(f\"WORLD_SIZE: {dist.get_world_size()}\")\n\n    # validation_dataset = create_dataset(all_config.validation_data, all_config.model)\n\n    global_batch_size = all_config.train_data.per_replica_batch_size * dist.get_world_size()\n\n    metrics = create_metrics(device)\n\n    model = TwhinModel(all_config.model, all_config.train_data)\n    apply_optimizers(model, all_config.model)\n    model = maybe_shard_model(model, device=device)\n    optimizer, scheduler = build_optimizer(model=model, config=all_config.model)\n\n    loss_fn = F.binary_cross_entropy_with_logits\n    model_and_loss = TwhinModelAndLoss(\n      model, loss_fn, data_config=all_config.train_data, device=device\n    )\n\n    ctl.train(\n      model=model_and_loss,\n      optimizer=optimizer,\n      device=device,\n      save_dir=save_dir,\n      logging_interval=all_config.training.train_log_every_n,\n      train_steps=all_config.training.num_train_steps,\n      checkpoint_frequency=all_config.training.checkpoint_every_n,\n      dataset=train_dataset.dataloader(remote=False),\n      worker_batch_size=global_batch_size,\n      num_workers=0,\n      scheduler=scheduler,\n      initial_checkpoint_dir=all_config.training.initial_checkpoint_dir,\n      gradient_accumulation=all_config.training.gradient_accumulation,\n    )\n\n\ndef main(argv):\n  logging.info(\"Starting\")\n\n  logging.info(f\"parsing config from {FLAGS.config_yaml_path}...\")\n  all_config = setup_configuration(  # type: ignore[var-annotated]\n    TwhinConfig,\n    yaml_path=FLAGS.config_yaml_path,\n  )\n\n  run(\n    all_config,\n    save_dir=FLAGS.save_dir,\n  )\n\n\nif __name__ 
== \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "projects/twhin/scripts/docker_run.sh",
    "content": "#! /bin/sh\n\ndocker run -it --rm \\\n  -v $HOME/workspace/tml:/usr/src/app/tml \\\n  -v $HOME/.config:/root/.config \\\n  -w /usr/src/app \\\n  -e PYTHONPATH=\"/usr/src/app/\" \\\n  --network host \\\n  -e SPEC_TYPE=chief \\\n  local/torch \\\n  bash tml/projects/twhin/scripts/run_in_docker.sh\n"
  },
  {
    "path": "projects/twhin/scripts/run_in_docker.sh",
    "content": "#! /bin/sh\n\ntorchrun \\\n  --standalone \\\n  --nnodes 1 \\\n  --nproc_per_node 2 \\\n  /usr/src/app/tml/projects/twhin/run.py \\\n  --config_yaml_path=\"/usr/src/app/tml/projects/twhin/config/local.yaml\" \\\n  --save_dir=\"/some/save/dir\"\n"
  },
  {
    "path": "projects/twhin/test_optimizer.py",
    "content": "import pytest\nimport unittest\n\nfrom tml.projects.twhin.models.models import TwhinModel, apply_optimizers\nfrom tml.projects.twhin.models.test_models import twhin_model_config, twhin_data_config\nfrom tml.projects.twhin.optimizer import build_optimizer\nfrom tml.model import maybe_shard_model\nfrom tml.common.testing_utils import mock_pg\n\n\nimport torch\nfrom torch.nn import functional as F\n\n\ndef test_twhin_optimizer():\n  model_config = twhin_model_config()\n  data_config = twhin_data_config()\n\n  loss_fn = F.binary_cross_entropy_with_logits\n  with mock_pg():\n    model = TwhinModel(model_config, data_config)\n    apply_optimizers(model, model_config)\n    model = maybe_shard_model(model, device=torch.device(\"cpu\"))\n\n    optimizer, _ = build_optimizer(model, model_config)\n\n    # make sure there is one combined fused optimizer and one translation optimizer\n    assert len(optimizer.optimizers) == 2\n    fused_opt_tup, _ = optimizer.optimizers\n    _, fused_opt = fused_opt_tup\n\n    # make sure there are two tables for which the fused opt has parameters\n    assert len(fused_opt.param_groups) == 2\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\nline-length = 100\ninclude = '\\.pyi?$'\nexclude = '''\n/(\n    \\.git\n  | \\.hg\n  | \\.pem\n  | \\.mypy_cache\n  | \\.tox\n  | \\.venv\n  | _build\n  | buck-out\n  | build\n  | dist\n)/\n'''\n"
  },
  {
    "path": "reader/__init__.py",
    "content": ""
  },
  {
    "path": "reader/dataset.py",
    "content": "\"\"\"Dataset to be overwritten that can work with or without distributed reading.\n\n- Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch.\n- Readers can be colocated or off trainer machines.\n\n\"\"\"\nimport abc\nimport functools\nimport random\nfrom typing import Optional\n\nfrom fsspec.implementations.local import LocalFileSystem\nimport pyarrow.dataset as pads\nimport pyarrow as pa\nimport pyarrow.parquet\nimport pyarrow.flight\nfrom pyarrow.ipc import IpcWriteOptions\nimport torch\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.machines import environment as env\nimport tml.reader.utils as reader_utils\nfrom tml.common.filesystem import infer_fs\nfrom tml.ml_logging.torch_logging import logging\n\n\nclass _Reader(pa.flight.FlightServerBase):\n  \"\"\"Distributed reader flight server wrapping a dataset.\"\"\"\n\n  def __init__(self, location: str, ds: \"Dataset\"):\n    super().__init__(location=location)\n    self._location = location\n    self._ds = ds\n\n  def do_get(self, _, __):\n    # NB: An updated schema (to account for column selection) has to be given the stream.\n    schema = next(iter(self._ds.to_batches())).schema\n    batches = self._ds.to_batches()\n    return pa.flight.RecordBatchStream(\n      data_source=pa.RecordBatchReader.from_batches(\n        schema=schema,\n        batches=batches,\n      ),\n      options=IpcWriteOptions(use_threads=True),\n    )\n\n\nclass Dataset(torch.utils.data.IterableDataset):\n  LOCATION = \"grpc://0.0.0.0:2222\"\n\n  def __init__(self, file_pattern: str, **dataset_kwargs) -> None:\n    \"\"\"Specify batch size and column to select for.\n\n    Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.\n    \"\"\"\n    self._file_pattern = file_pattern\n    self._fs = infer_fs(self._file_pattern)\n    self._dataset_kwargs = dataset_kwargs\n    logging.info(f\"Using dataset_kwargs: 
{self._dataset_kwargs}\")\n    self._files = self._fs.glob(self._file_pattern)\n    assert len(self._files) > 0, f\"No files found at {self._file_pattern}\"\n    logging.info(f\"Found {len(self._files)} files: {', '.join(self._files[:4])}, ...\")\n    self._schema = pa.parquet.read_schema(self._files[0], filesystem=self._fs)\n    self._validate_columns()\n\n  def _validate_columns(self):\n    columns = set(self._dataset_kwargs.get(\"columns\", []))\n    wrong_columns = set(columns) - set(self._schema.names)\n    if wrong_columns:\n      raise Exception(f\"Specified columns {list(wrong_columns)} not in schema.\")\n\n  def serve(self):\n    self.reader = _Reader(location=self.LOCATION, ds=self)\n    self.reader.serve()\n\n  def _create_dataset(self):\n    return pads.dataset(\n      source=random.sample(self._files, len(self._files))[0],\n      format=\"parquet\",\n      filesystem=self._fs,\n      exclude_invalid_files=False,\n    )\n\n  def to_batches(self):\n    \"\"\"This allows the init to control reading settings.\n\n    Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.\n\n    Perform `drop_remainder` behavior to afix batch size.\n    This does not shift our data distribution bc of volume and file-level shuffling on every repeat.\n    \"\"\"\n    batch_size = self._dataset_kwargs[\"batch_size\"]\n    while True:\n      ds = self._create_dataset()\n      for batch in ds.to_batches(**self._dataset_kwargs):\n        if batch.num_rows < batch_size:\n          logging.info(f\"Dropping remainder ({batch.num_rows}/{batch_size})\")\n          break\n        yield batch\n\n  @abc.abstractmethod\n  def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:\n    raise NotImplementedError\n\n  def dataloader(self, remote: bool = False):\n    if not remote:\n      return map(self.pa_to_batch, self.to_batches())\n    readers = get_readers(2)\n    return map(self.pa_to_batch, 
reader_utils.roundrobin(*readers))\n\n\nGRPC_OPTIONS = [\n  (\"GRPC_ARG_KEEPALIVE_TIME_MS\", 60000),\n  (\"GRPC_ARG_MIN_RECONNECT_BACKOFF_MS\", 2000),\n  (\"GRPC_ARG_MAX_METADATA_SIZE\", 1024 * 1024 * 1024),\n]\n\n\ndef get_readers(num_readers_per_worker: int):\n  addresses = env.get_flight_server_addresses()\n\n  readers = []\n  for worker in addresses:\n    logging.info(f\"Attempting connection to reader {worker}.\")\n    client = pa.flight.connect(worker, generic_options=GRPC_OPTIONS)\n    client.wait_for_available(60)\n    reader = client.do_get(None).to_reader()\n    logging.info(f\"Connected reader to {worker}.\")\n    readers.append(reader)\n  return readers\n"
  },
  {
    "path": "reader/dds.py",
    "content": "\"\"\"Dataset service orchestrated by a TFJob\n\"\"\"\nfrom typing import Optional\nimport uuid\n\nfrom tml.ml_logging.torch_logging import logging\nimport tml.machines.environment as env\n\nimport packaging.version\nimport tensorflow as tf\n\ntry:\n  import tensorflow_io as tfio\nexcept:\n  pass\nfrom tensorflow.python.data.experimental.ops.data_service_ops import (\n  _from_dataset_id,\n  _register_dataset,\n)\nimport torch.distributed as dist\n\n\ndef maybe_start_dataset_service():\n  if not env.has_readers():\n    return\n\n  if packaging.version.parse(tf.__version__) < packaging.version.parse(\"2.5\"):\n    raise Exception(f\"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}\")\n\n  if env.is_dispatcher():\n    logging.info(f\"env.get_reader_port() = {env.get_reader_port()}\")\n    logging.info(f\"env.get_dds_journaling_dir() = {env.get_dds_journaling_dir()}\")\n    work_dir = env.get_dds_journaling_dir()\n    server = tf.data.experimental.service.DispatchServer(\n      tf.data.experimental.service.DispatcherConfig(\n        port=env.get_reader_port(),\n        protocol=\"grpc\",\n        work_dir=work_dir,\n        fault_tolerant_mode=bool(work_dir),\n      )\n    )\n    server.join()\n\n  elif env.is_reader():\n    logging.info(f\"env.get_reader_port() = {env.get_reader_port()}\")\n    logging.info(f\"env.get_dds_dispatcher_address() = {env.get_dds_dispatcher_address()}\")\n    logging.info(f\"env.get_dds_worker_address() = {env.get_dds_worker_address()}\")\n    server = tf.data.experimental.service.WorkerServer(\n      tf.data.experimental.service.WorkerConfig(\n        port=env.get_reader_port(),\n        dispatcher_address=env.get_dds_dispatcher_address(),\n        worker_address=env.get_dds_worker_address(),\n        protocol=\"grpc\",\n      )\n    )\n    server.join()\n\n\ndef register_dataset(\n  dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = \"AUTO\"\n):\n  if dist.get_rank() == 0:\n    
dataset_id = _register_dataset(\n      service=dataset_service,\n      dataset=dataset,\n      compression=compression,\n    )\n    job_name = uuid.uuid4().hex[:8]\n    id_and_job = [dataset_id.numpy(), job_name]\n    logging.info(f\"rank{dist.get_rank()}: Created dds job with {dataset_id.numpy()}, {job_name}\")\n  else:\n    id_and_job = [None, None]\n\n  dist.broadcast_object_list(id_and_job, src=0)\n  return tuple(id_and_job)\n\n\ndef distribute_from_dataset_id(\n  dataset_service: str,\n  dataset_id: int,\n  job_name: Optional[str],\n  compression: Optional[str] = \"AUTO\",\n  prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,\n) -> tf.data.Dataset:\n  logging.info(f\"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}\")\n  dataset = _from_dataset_id(\n    processing_mode=\"parallel_epochs\",\n    service=dataset_service,\n    dataset_id=dataset_id,\n    job_name=job_name,\n    element_spec=None,\n    compression=compression,\n  )\n  if prefetch is not None:\n    dataset = dataset.prefetch(prefetch)\n  return dataset\n\n\ndef maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:\n  \"\"\"Torch-compatible and distributed-training-aware dataset service distributor.\n\n  - rank 0 process will register the given dataset.\n  - rank 0 process will broadcast job name and dataset id.\n  - all rank processes will consume from the same job/dataset.\n\n  Without this, dataset workers will try to serve 1 job per rank process and OOM.\n\n  \"\"\"\n  if not env.has_readers():\n    return dataset\n  dataset_service = env.get_dds()\n\n  logging.info(f\"using DDS = {dataset_service}\")\n  dataset_id, job_name = register_dataset(dataset=dataset, dataset_service=dataset_service)\n  dataset = distribute_from_dataset_id(\n    dataset_service=dataset_service, dataset_id=dataset_id, job_name=job_name\n  )\n  return dataset\n\n\nif __name__ == \"__main__\":\n  maybe_start_dataset_service()\n"
  },
  {
    "path": "reader/test_dataset.py",
    "content": "import multiprocessing as mp\nimport os\nfrom unittest.mock import patch\n\nimport tml.reader.utils as reader_utils\nfrom tml.reader.dataset import Dataset\n\nimport pyarrow as pa\nimport pyarrow.parquet as pq\nimport pytest\nimport torch\n\n\ndef create_dataset(tmpdir):\n\n  table = pa.table(\n    {\n      \"year\": [2020, 2022, 2021, 2022, 2019, 2021],\n      \"n_legs\": [2, 2, 4, 4, 5, 100],\n    }\n  )\n  file_path = tmpdir\n  pq.write_to_dataset(table, root_path=str(file_path))\n\n  class MockDataset(Dataset):\n    def __init__(self, *args, **kwargs):\n      super().__init__(*args, **kwargs)\n      self._pa_to_batch = reader_utils.create_default_pa_to_batch(self._schema)\n\n    def pa_to_batch(self, batch):\n      return self._pa_to_batch(batch)\n\n  return MockDataset(file_pattern=str(file_path / \"*\"), batch_size=2)\n\n\ndef test_dataset(tmpdir):\n  ds = create_dataset(tmpdir)\n  batch = next(iter(ds.dataloader(remote=False)))\n  assert batch.batch_size == 2\n  assert torch.equal(batch.year, torch.Tensor([2020, 2022]))\n  assert torch.equal(batch.n_legs, torch.Tensor([2, 2]))\n\n\n@pytest.mark.skipif(\n  os.environ.get(\"GITHUB_WORKSPACE\") is not None,\n  reason=\"Multiprocessing doesn't work on github yet.\",\n)\ndef test_distributed_dataset(tmpdir):\n  MOCK_ENV = {\"TEMP_SLURM_NUM_READERS\": \"1\"}\n\n  def _client():\n    with patch.dict(os.environ, MOCK_ENV):\n      with patch(\n        \"tml.reader.dataset.env.get_flight_server_addresses\", return_value=[\"grpc://localhost:2222\"]\n      ):\n        ds = create_dataset(tmpdir)\n        batch = next(iter(ds.dataloader(remote=True)))\n        assert batch.batch_size == 2\n        assert torch.equal(batch.year, torch.Tensor([2020, 2022]))\n        assert torch.equal(batch.n_legs, torch.Tensor([2, 2]))\n\n  def _worker():\n    ds = create_dataset(tmpdir)\n    ds.serve()\n\n  worker = mp.Process(target=_worker)\n  client = mp.Process(target=_client)\n  worker.start()\n  client.start()\n  
client.join()\n  assert not client.exitcode\n  worker.kill()\n  client.kill()\n"
  },
  {
    "path": "reader/test_utils.py",
    "content": "import tml.reader.utils as reader_utils\n\n\ndef test_rr():\n  options = [\"a\", \"b\", \"c\"]\n  rr = reader_utils.roundrobin(options)\n  for i, v in enumerate(rr):\n    assert v == options[i % 3]\n    if i > 4:\n      break\n"
  },
  {
    "path": "reader/utils.py",
    "content": "\"\"\"Reader utilities.\"\"\"\nimport itertools\nimport time\nfrom typing import Optional\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.ml_logging.torch_logging import logging\n\nimport pyarrow as pa\nimport torch\n\n\ndef roundrobin(*iterables):\n  \"\"\"Round robin through provided iterables, useful for simple load balancing.\n\n  Adapted from https://docs.python.org/3/library/itertools.html.\n\n  \"\"\"\n  num_active = len(iterables)\n  nexts = itertools.cycle(iter(it).__next__ for it in iterables)\n  while num_active:\n    try:\n      for _next in nexts:\n        result = _next()\n        yield result\n    except StopIteration:\n      # Remove the iterator we just exhausted from the cycle.\n      num_active -= 1\n      nexts = itertools.cycle(itertools.islice(nexts, num_active))\n      logging.warning(f\"Iterable exhausted, {num_active} iterables left.\")\n    except Exception as exc:\n      logging.warning(f\"Iterable raised exception {exc}, ignoring.\")\n      # continue\n      raise\n\n\ndef speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):\n  num_examples = 0\n  prev = time.perf_counter()\n  for idx, batch in enumerate(data_loader):\n    if idx > max_steps:\n      break\n    if peek and idx % peek == 0:\n      logging.info(f\"Batch: {batch}\")\n    num_examples += batch.batch_size\n    if idx % frequency == 0:\n      now = time.perf_counter()\n      elapsed = now - prev\n      logging.info(\n        f\"step: {idx}, \"\n        f\"elapsed(s): {elapsed}, \"\n        f\"examples: {num_examples}, \"\n        f\"ex/s: {num_examples / elapsed}, \"\n      )\n      prev = now\n      num_examples = 0\n\n\ndef pa_to_torch(array: pa.array) -> torch.Tensor:\n  return torch.from_numpy(array.to_numpy())\n\n\ndef create_default_pa_to_batch(schema) -> DataclassBatch:\n  \"\"\" \"\"\"\n  _CustomBatch = DataclassBatch.from_schema(\"DefaultBatch\", schema=schema)\n\n  def get_imputation_value(pa_type):\n    type_map = 
{\n      pa.float64(): pa.scalar(0, type=pa.float64()),\n      pa.int64(): pa.scalar(0, type=pa.int64()),\n      pa.string(): pa.scalar(\"\", type=pa.string()),\n    }\n    if pa_type not in type_map:\n      raise Exception(f\"Imputation for type {pa_type} not supported.\")\n    return type_map[pa_type]\n\n  def _impute(array: pa.array) -> pa.array:\n    return array.fill_null(get_imputation_value(array.type))\n\n  def _column_to_tensor(record_batch: pa.RecordBatch):\n    tensors = {\n      col_name: pa_to_torch(_impute(record_batch.column(col_name)))\n      for col_name in record_batch.schema.names\n    }\n    return _CustomBatch(**tensors)\n\n  return _column_to_tensor\n"
  },
  {
    "path": "tools/pq.py",
    "content": "\"\"\"Local reader of parquet files.\n\n1. Make sure you are initialized locally:\n  ```\n  ./images/init_venv_macos.sh\n  ```\n2. Activate\n  ```\n  source ~/tml_venv/bin/activate\n  ```\n3. Use tool, e.g.\n\n  `head` prints the first `--num` rows of the dataset.\n  ```\n  python3 tools/pq.py \\\n    --num 5 --path \"tweet_eng/small/edges/all/*\" \\\n    head\n  ```\n\n  `distinct` prints the observed values in the first `--num` rows for the specified columns.\n  ```\n  python3 tools/pq.py \\\n    --num 1000000000 --columns '[\"rel\"]' \\\n    --path \"tweet_eng/small/edges/all/*\" \\\n    distinct\n  ```\n\n\"\"\"\nfrom typing import List, Optional\n\nfrom tml.common.filesystem import infer_fs\n\nimport fire\nimport pandas as pd\nimport pyarrow as pa\nimport pyarrow.dataset as pads\nimport pyarrow.parquet as pq\n\n\ndef _create_dataset(path: str):\n  fs = infer_fs(path)\n  files = fs.glob(path)\n  return pads.dataset(files, format=\"parquet\", filesystem=fs)\n\n\nclass PqReader:\n  def __init__(\n    self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None\n  ):\n    self._ds = _create_dataset(path)\n    self._batch_size = batch_size\n    self._num = num\n    self._columns = columns\n\n  def __iter__(self):\n    batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)\n    rows_seen = 0\n    for count, record in enumerate(batches):\n      if self._num and rows_seen >= self._num:\n        break\n      yield record\n      rows_seen += record.data.num_rows\n\n  def _head(self):\n    total_read = self._num * self.bytes_per_row\n    if total_read >= int(500e6):\n      raise Exception(\n        \"Sorry you're trying to read more than 500 MB \" f\"into memory ({total_read} bytes).\"\n      )\n    return self._ds.head(self._num, columns=self._columns)\n\n  @property\n  def bytes_per_row(self) -> int:\n    nbits = 0\n    for t in self._ds.schema.types:\n      try:\n        nbits += t.bit_width\n   
   except:\n        # Just estimate size if it is variable\n        nbits += 8\n    return nbits // 8\n\n  def schema(self):\n    print(f\"\\n# Schema\\n{self._ds.schema}\")\n\n  def head(self):\n    \"\"\"Displays first --num rows.\"\"\"\n    print(self._head().to_pandas())\n\n  def distinct(self):\n    \"\"\"Displays unique values seen in specified columns in the first `--num` rows.\n\n    Useful for getting an approximate vocabulary for certain columns.\n\n    \"\"\"\n    for col_name, column in zip(self._head().column_names, self._head().columns):\n      print(col_name)\n      print(\"unique:\", column.unique().to_pylist())\n\n\nif __name__ == \"__main__\":\n  pd.set_option(\"display.max_columns\", None)\n  pd.set_option(\"display.max_rows\", None)\n  fire.Fire(PqReader)\n"
  }
]