Repository: twitter/the-algorithm-ml
Branch: main
Commit: b85210863f7a
Files: 111
Total size: 376.7 KB

Directory structure:
gitextract_800cdojr/

├── .github/
│   └── workflows/
│       └── main.yml
├── .gitignore
├── .pre-commit-config.yaml
├── COPYING
├── LICENSE.torchrec
├── README.md
├── common/
│   ├── __init__.py
│   ├── batch.py
│   ├── checkpointing/
│   │   ├── __init__.py
│   │   └── snapshot.py
│   ├── device.py
│   ├── filesystem/
│   │   ├── __init__.py
│   │   ├── test_infer_fs.py
│   │   └── util.py
│   ├── log_weights.py
│   ├── modules/
│   │   └── embedding/
│   │       ├── config.py
│   │       └── embedding.py
│   ├── run_training.py
│   ├── test_device.py
│   ├── testing_utils.py
│   ├── utils.py
│   └── wandb.py
├── core/
│   ├── __init__.py
│   ├── config/
│   │   ├── __init__.py
│   │   ├── base_config.py
│   │   ├── base_config_test.py
│   │   ├── config_load.py
│   │   ├── test_config_load.py
│   │   └── training.py
│   ├── custom_training_loop.py
│   ├── debug_training_loop.py
│   ├── loss_type.py
│   ├── losses.py
│   ├── metric_mixin.py
│   ├── metrics.py
│   ├── test_metrics.py
│   ├── test_train_pipeline.py
│   └── train_pipeline.py
├── images/
│   ├── init_venv.sh
│   └── requirements.txt
├── machines/
│   ├── environment.py
│   ├── get_env.py
│   ├── is_venv.py
│   └── list_ops.py
├── metrics/
│   ├── __init__.py
│   ├── aggregation.py
│   ├── auroc.py
│   └── rce.py
├── ml_logging/
│   ├── __init__.py
│   ├── absl_logging.py
│   ├── test_torch_logging.py
│   └── torch_logging.py
├── model.py
├── optimizers/
│   ├── __init__.py
│   ├── config.py
│   └── optimizer.py
├── projects/
│   ├── __init__.py
│   ├── home/
│   │   └── recap/
│   │       ├── FEATURES.md
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── config/
│   │       │   ├── home_recap_2022/
│   │       │   │   └── segdense.json
│   │       │   └── local_prod.yaml
│   │       ├── config.py
│   │       ├── data/
│   │       │   ├── __init__.py
│   │       │   ├── config.py
│   │       │   ├── dataset.py
│   │       │   ├── generate_random_data.py
│   │       │   ├── preprocessors.py
│   │       │   ├── tfe_parsing.py
│   │       │   └── util.py
│   │       ├── embedding/
│   │       │   └── config.py
│   │       ├── main.py
│   │       ├── model/
│   │       │   ├── __init__.py
│   │       │   ├── config.py
│   │       │   ├── entrypoint.py
│   │       │   ├── feature_transform.py
│   │       │   ├── mask_net.py
│   │       │   ├── mlp.py
│   │       │   ├── model_and_loss.py
│   │       │   └── numeric_calibration.py
│   │       ├── optimizer/
│   │       │   ├── __init__.py
│   │       │   ├── config.py
│   │       │   └── optimizer.py
│   │       └── script/
│   │           ├── create_random_data.sh
│   │           └── run_local.sh
│   └── twhin/
│       ├── README.md
│       ├── config/
│       │   └── local.yaml
│       ├── config.py
│       ├── data/
│       │   ├── config.py
│       │   ├── data.py
│       │   ├── edges.py
│       │   ├── test_data.py
│       │   └── test_edges.py
│       ├── machines.yaml
│       ├── metrics.py
│       ├── models/
│       │   ├── config.py
│       │   ├── models.py
│       │   └── test_models.py
│       ├── optimizer.py
│       ├── run.py
│       ├── scripts/
│       │   ├── docker_run.sh
│       │   └── run_in_docker.sh
│       └── test_optimizer.py
├── pyproject.toml
├── reader/
│   ├── __init__.py
│   ├── dataset.py
│   ├── dds.py
│   ├── test_dataset.py
│   ├── test_utils.py
│   └── utils.py
└── tools/
    └── pq.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/main.yml
================================================
name: Python package

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v3
      # - uses: pre-commit/action@v3.0.0
      #   name: Run pre-commit checks (pylint/yapf/isort)
      #   env:
      #     SKIP: insert-license
      #   with:
      #     extra_args: --hook-stage push --all-files
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip" # caching pip dependencies
      - name: install packages
        run: |
          python -m pip install --upgrade pip
          pip install --no-deps -r images/requirements.txt
          # - name: ssh access
          #   uses: lhotari/action-upterm@v1
          #   with:
          #     limit-access-to-actor: true
          #     limit-access-to-users: arashd
      - name: run tests
        run: |
          # Environment variables are reset in between steps.
          mkdir /tmp/github_testing
          ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml
          export PYTHONPATH="/tmp/github_testing:$PYTHONPATH"
          pytest -vv
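The "run tests" step above symlinks the checkout to `/tmp/github_testing/tml` and prepends `/tmp/github_testing` to `PYTHONPATH`, so the checkout can be imported under the package name `tml` regardless of what the checkout directory is called. A minimal Python sketch of the same idea for a single interpreter session follows. The helper name `make_importable_as` is hypothetical. The sketch assumes a POSIX system where `os.symlink` works without extra privileges, and it relies on the checkout root being importable as an implicit namespace package (the directory listing shows no top-level `__init__.py`).

```python
import importlib
import os
import sys
import tempfile


def make_importable_as(repo_root: str, package_name: str = "tml") -> str:
  """Expose `repo_root` as a top-level package called `package_name`.

  Hypothetical helper mirroring the CI step: make a scratch directory,
  symlink the checkout into it under `package_name`, and put the scratch
  directory at the front of the import path. Returns the scratch directory.
  """
  scratch = tempfile.mkdtemp(prefix="github_testing_")
  # Equivalent of `ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml`.
  os.symlink(
    os.path.abspath(repo_root),
    os.path.join(scratch, package_name),
    target_is_directory=True,
  )
  # Equivalent of `export PYTHONPATH="/tmp/github_testing:$PYTHONPATH"`,
  # but it only affects the current interpreter, not child processes.
  sys.path.insert(0, scratch)
  # Drop cached directory listings so the new path entry is seen immediately.
  importlib.invalidate_caches()
  return scratch
```

After calling `make_importable_as(checkout_dir)`, `import tml.common` should load `common/__init__.py` from the checkout. The workflow sets `PYTHONPATH` instead because the mapping must also be visible to the `pytest` process and anything it spawns, and, as the step's own comment notes, environment variables do not carry over between steps.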


================================================
FILE: .gitignore
================================================
# Mac
.DS_Store

# Vim
*.py.swp

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.hypothesis

venv


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
-   repo: https://github.com/pausan/cblack
    rev: release-22.3.0
    hooks:
    - id: cblack
      name: cblack
      description: "Black: The uncompromising Python code formatter - 2 space indent fork"
      entry: cblack . -l 100
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
    -   id: trailing-whitespace
    -   id: end-of-file-fixer
    -   id: check-yaml
    -   id: check-added-large-files
    -   id: check-merge-conflict


================================================
FILE: COPYING
================================================
                    GNU AFFERO GENERAL PUBLIC LICENSE
                       Version 3, 19 November 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.

  A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate.  Many developers of free software are heartened and
encouraged by the resulting cooperation.  However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.

  The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community.  It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server.  Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.

  An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals.  This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU Affero General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License.  Each licensee is addressed as "you".  "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy.  The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy.  Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies.  Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License.  If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it.  "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form.  A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities.  However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work.  For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met.  This License explicitly affirms your unlimited
permission to run the unmodified Program.  The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work.  This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force.  You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright.  Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under
the conditions stated below.  Sublicensing is not allowed; section 10
makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7.  This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.  This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged.  This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit.  Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source.  This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge.  You need not require recipients to copy the
    Corresponding Source along with the object code.  If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source.  Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling.  In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage.  For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product.  A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source.  The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

  If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information.  But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

  The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed.  Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law.  If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it.  (Additional permissions may be written to require their own
removal in certain cases when you modify the work.)  You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

  All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10.  If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term.  If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly
provided under this License.  Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

  However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

  Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

  Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License.  If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or
run a copy of the Program.  Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance.  However,
nothing other than this License grants you permission to propagate or
modify any covered work.  These actions infringe copyright if you do
not accept this License.  Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License.  You are not responsible
for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations.  If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License.  For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based.  The
work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version.  For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

  Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement).  To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

  If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients.  "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

  If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

  A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License.  You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all.  For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

  13. Remote Network Interaction; Use with the GNU General Public License.

  Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software.  This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.

  Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work.  The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time.  Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

  Each version is given a distinguishing version number.  If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation.  If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.

  If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

  Later license versions may give you additional or different
permissions.  However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

  15. Disclaimer of Warranty.

  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

  16. Limitation of Liability.

  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

  17. Interpretation of Sections 15 and 16.

  If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program.  It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

  If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source.  For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code.  There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.

  You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.


================================================
FILE: LICENSE.torchrec
================================================
A few files here (where it is specifically noted in comments) are based on code from torchrec but
adapted for our use. Torchrec license is below:


BSD 3-Clause License

Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
This project open-sources some of the ML models used at Twitter.

Currently these are:

1. The "For You" Heavy Ranker (projects/home/recap).

2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387


This project can be run inside a Python virtualenv. We have only tested it on Linux machines, and because it uses torchrec, it works best with an Nvidia GPU. To set up the environment, run

`./images/init_venv.sh` (Linux only).

The READMEs of each project contain instructions about how to run each project.


================================================
FILE: common/__init__.py
================================================


================================================
FILE: common/batch.py
================================================
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
import abc
from dataclasses import dataclass
import dataclasses

import torch
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
  @abc.abstractmethod
  def as_dict(self) -> Dict:
    raise NotImplementedError

  def to(self, device: torch.device, non_blocking: bool = False):
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
    return self.__class__(**args)

  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    for feature_value in self.as_dict().values():
      feature_value.record_stream(stream)

  def pin_memory(self):
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.pin_memory()
    return self.__class__(**args)

  def __repr__(self) -> str:
    def obj2str(v):
      return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

    return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

  @property
  def batch_size(self) -> int:
    for tensor in self.as_dict().values():
      if tensor is None:
        continue
      if not isinstance(tensor, torch.Tensor):
        continue
      return tensor.shape[0]
    raise Exception("Could not determine batch size from tensors.")


@dataclass
class DataclassBatch(BatchBase):
  @classmethod
  def feature_names(cls):
    return list(cls.__dataclass_fields__.keys())

  def as_dict(self):
    return {
      feature_name: getattr(self, feature_name)
      for feature_name in self.feature_names()
      if hasattr(self, feature_name)
    }

  @staticmethod
  def from_schema(name: str, schema):
    """Creates a custom batch subclass with one torch.Tensor field (default None) per schema column.

    Only valid if all columns can be represented as a torch.Tensor.
    """
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(column, torch.Tensor, dataclasses.field(default=None)) for column in schema.names],
      bases=(DataclassBatch,),
    )

  @staticmethod
  def from_fields(name: str, fields: dict):
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
      bases=(DataclassBatch,),
    )


class DictionaryBatch(BatchBase, dict):
  def as_dict(self) -> Dict:
    return self


================================================
FILE: common/checkpointing/__init__.py
================================================
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot


================================================
FILE: common/checkpointing/snapshot.py
================================================
import os
import time
from typing import Any, Dict, List, Optional

from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs

import torchsnapshot


DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"


class Snapshot:
  """Checkpoints using torchsnapshot.

  Also saves the step and walltime, which the training loop updates.

  """

  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    self.save_dir = save_dir
    self.state = state
    self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

  @property
  def step(self):
    return self.state["extra_state"]["step"]

  @step.setter
  def step(self, step: int) -> None:
    self.state["extra_state"]["step"] = step

  @property
  def walltime(self):
    return self.state["extra_state"]["walltime"]

  @walltime.setter
  def walltime(self, walltime: float) -> None:
    self.state["extra_state"]["walltime"] = walltime

  def save(self, global_step: int) -> "PendingSnapshot":
    """Saves checkpoint with given global_step."""
    path = os.path.join(self.save_dir, str(global_step))
    logging.info(f"Saving snapshot global_step {global_step} to {path}.")
    start_time = time.time()
    # Take the snapshot asynchronously. The snapshot is consistent: state changes made after this
    # method returns do not affect it. Storage I/O is performed in the background.
    snapshot = torchsnapshot.Snapshot.async_take(
      app_state=self.state,
      path=path,
      # commented out because DistributedModelParallel model saving
      # errors with this on multi-GPU. With it removed, CPU, single
      # GPU, and multi-GPU training all successfully checkpoint.
      # replicated=["**"],
    )
    logging.info(f"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s)")
    return snapshot

  def restore(self, checkpoint: str) -> None:
    """Restores a given checkpoint."""
    snapshot = torchsnapshot.Snapshot(path=checkpoint)
    logging.info(f"Restoring snapshot from {snapshot.path}.")
    start_time = time.time()
    # We can remove the try-except when we are confident that we no longer need to restore from
    # checkpoints from before walltime was added
    try:
      # checkpoints that do not have extra_state[walltime] will fail here
      snapshot.restore(self.state)
    except RuntimeError:
      # extra_state[walltime] does not exist in the checkpoint, but step should be there so restore it
      self.state["extra_state"] = torchsnapshot.StateDict(step=0)
      snapshot.restore(self.state)
      # we still need to ensure that extra_state has walltime in it
      self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)

    logging.info(f"Restored snapshot from {snapshot.path} ({time.time() - start_time:.05}s).")

  @classmethod
  def get_torch_snapshot(
    cls,
    snapshot_path: str,
    global_step: Optional[int] = None,
    missing_ok: bool = False,
  ) -> torchsnapshot.Snapshot:
    """Get torch stateless snapshot, without actually loading it.
    Args:
      snapshot_path: path to the model snapshot
      global_step: restores from this checkpoint if specified.
      missing_ok: if True and checkpoints do not exist, returns without restoration.
    """
    path = get_checkpoint(snapshot_path, global_step, missing_ok)
    logging.info(f"Loading snapshot from {path}.")
    return torchsnapshot.Snapshot(path=path)

  @classmethod
  def load_snapshot_to_weight(
    cls,
    embedding_snapshot: torchsnapshot.Snapshot,
    snapshot_emb_name: str,
    weight_tensor,
  ) -> None:
    """Loads pretrained embeddings from the snapshot into the model.

    Uses the partial loading mechanism of torchsnapshot.

    Args:
      embedding_snapshot: Snapshot containing the pretrained embeddings (EBC).
      snapshot_emb_name: Name of the layer in the *snapshot* model that contains the EBC.
      weight_tensor: Embedding tensor of the *current* model, into which the embeddings are loaded.
    """
    start_time = time.time()
    manifest = embedding_snapshot.get_manifest()
    snapshot_path_to_load = None
    for path in manifest.keys():
      if path.startswith("0") and snapshot_emb_name in path:
        snapshot_path_to_load = path
    if snapshot_path_to_load is None:
      raise ValueError(f"No entry matching {snapshot_emb_name} found in the snapshot manifest.")
    embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
    logging.info(
      f"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s",
      rank=-1,
    )
    logging.info(f"Snapshot loaded to {weight_tensor.metadata()}", rank=-1)


def _eval_subdir(checkpoint_path: str) -> str:
  return os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)


def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
  return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")


def is_done_eval(checkpoint_path: str, eval_partition: str):
  return infer_fs(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))


def mark_done_eval(checkpoint_path: str, eval_partition: str):
  infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))


def step_from_checkpoint(checkpoint: str) -> int:
  return int(os.path.basename(checkpoint))


def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
  """Simplified equivalent of tf.train.checkpoints_iterator.

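  Args:
    save_dir: directory to poll for new checkpoints.
    seconds_to_sleep: time between polling calls, in seconds.
    timeout: how long, in seconds, to wait for a new checkpoint before the iterator stops.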

  """

  def _poll(last_checkpoint: Optional[str] = None):
    stop_time = time.time() + timeout
    while True:
      _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
      if not _checkpoint_path or _checkpoint_path == last_checkpoint:
        if time.time() + seconds_to_sleep > stop_time:
          logging.info(
            f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
          )
          return None
        logging.info(f"Waiting for next available checkpoint from {save_dir}.")
        time.sleep(seconds_to_sleep)
      else:
        logging.info(f"Found latest checkpoint {_checkpoint_path}.")
        return _checkpoint_path

  checkpoint_path = None
  while True:
    new_checkpoint = _poll(checkpoint_path)
    if not new_checkpoint:
      return
    checkpoint_path = new_checkpoint
    yield checkpoint_path
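
# Illustrative sketch only (hypothetical helper, not part of the original module and
# not called anywhere): the polling contract of checkpoints_iterator above, written as
# a pure generator. `latest` stands in for get_checkpoint(save_dir, missing_ok=True),
# and `sleep`/`clock` are injected so the logic can be checked without real waiting.
# A path is yielded only when it is non-empty and differs from the last one yielded;
# iteration stops once no new path appears within `timeout` seconds.
def _sketch_poll_new_checkpoints(latest, seconds_to_sleep, timeout, sleep, clock):
  last = None
  while True:
    stop_time = clock() + timeout
    while True:
      current = latest()
      if current and current != last:
        break
      if clock() + seconds_to_sleep > stop_time:
        return  # timed out, mirroring _poll returning None
      sleep(seconds_to_sleep)
    last = current
    yield current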


def get_checkpoint(
  save_dir: str,
  global_step: Optional[int] = None,
  missing_ok: bool = False,
) -> str:
  """Gets latest checkpoint or checkpoint at specified global_step.

  Args:
    save_dir: directory containing the checkpoints.
    global_step: Finds this checkpoint if specified.
    missing_ok: if True and no checkpoints exist, returns an empty string instead of raising.

  """
  checkpoints = get_checkpoints(save_dir)
  if not checkpoints:
    if not missing_ok:
      raise Exception(f"No checkpoints found at {save_dir}")
    else:
      logging.info(f"No checkpoints found for restoration at {save_dir}.")
      return ""

  if global_step is None:
    return checkpoints[-1]

  logging.info(f"Found checkpoints: {checkpoints}")
  for checkpoint in checkpoints:
    step = step_from_checkpoint(checkpoint)
    if global_step == step:
      chosen_checkpoint = checkpoint
      break
  else:
    raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}")
  return chosen_checkpoint


def get_checkpoints(save_dir: str) -> List[str]:
  """Gets all checkpoints that have been fully written."""
  checkpoints = []
  fs = infer_fs(save_dir)
  if fs.exists(save_dir):
    prefix = GCS_PREFIX if is_gcs_fs(fs) else ""
    checkpoints = list(f"{prefix}{elem}" for elem in fs.ls(save_dir, detail=False))
    # Only take checkpoints that were fully written.
    checkpoints = list(
      filter(
        lambda path: fs.exists(f"{path}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}"),
        checkpoints,
      )
    )
    checkpoints = sorted(checkpoints, key=lambda path: int(os.path.basename(path)))
  return checkpoints
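
# Illustrative sketch only (hypothetical helper, local filesystem only, not part of
# the original module): the selection rule in get_checkpoints above. A step
# directory counts as a finished checkpoint only if its metadata file exists, and
# steps are sorted numerically, so "10" sorts after "9" (a string sort would not).
# `metadata_fname` is passed in here rather than read from torchsnapshot.
def _sketch_complete_checkpoints(save_dir, metadata_fname):
  import os

  if not os.path.isdir(save_dir):
    return []
  candidates = [os.path.join(save_dir, name) for name in os.listdir(save_dir)]
  complete = [p for p in candidates if os.path.exists(os.path.join(p, metadata_fname))]
  return sorted(complete, key=lambda p: int(os.path.basename(p)))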


def wait_for_evaluators(
  save_dir: str,
  partition_names: List[str],
  global_step: int,
  timeout: int,
) -> None:
  logging.info("Waiting for all evaluators to finish.")
  start_time = time.time()

  for checkpoint in checkpoints_iterator(save_dir):
    step = step_from_checkpoint(checkpoint)
    logging.info(f"Considering checkpoint {checkpoint} for global step {global_step}.")
    if step == global_step:
      while partition_names:
        if is_done_eval(checkpoint, partition_names[-1]):
          finished = partition_names.pop()
          logging.info(
            f"Checkpoint {checkpoint} marked as finished eval for partition {finished} at step {step}, still waiting for {partition_names}."
          )

        if time.time() - start_time >= timeout:
          logging.warning(
            f"Not all evaluators finished after waiting for {time.time() - start_time}"
          )
          return
        time.sleep(10)
      logging.info("All evaluators finished.")
      return

    if time.time() - start_time >= timeout:
      logging.warning(f"Not all evaluators finished after waiting for {time.time() - start_time}")
      return
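
# Illustrative sketch only (hypothetical helpers, local filesystem only, not part of
# the original module): the "done" marker layout used by mark_done_eval and
# is_done_eval. An evaluator touches <checkpoint>/evaled_by/<partition>_DONE, and
# wait_for_evaluators polls until every partition's marker exists. This sketch
# creates the evaled_by/ directory explicitly before touching the marker.
def _sketch_done_marker_path(checkpoint_path, partition):
  import os

  return os.path.join(checkpoint_path, "evaled_by", f"{partition}_DONE")


def _sketch_touch_done_marker(checkpoint_path, partition):
  import os

  marker = _sketch_done_marker_path(checkpoint_path, partition)
  os.makedirs(os.path.dirname(marker), exist_ok=True)
  open(marker, "a").close()  # equivalent of `touch`
  return marker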


================================================
FILE: common/device.py
================================================
import os

import torch
import torch.distributed as dist


def maybe_setup_tensorflow():
  try:
    import tensorflow as tf
  except ImportError:
    pass
  else:
    tf.config.set_visible_devices([], "GPU")  # disable tf gpu


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  if tf_ok:
    maybe_setup_tensorflow()

  device = torch.device("cpu")
  backend = "gloo"
  if torch.cuda.is_available():
    rank = os.environ["LOCAL_RANK"]
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
  if not torch.distributed.is_initialized():
    dist.init_process_group(backend)

  return device
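
# Illustrative sketch only (hypothetical helper, not called anywhere): the decision
# made by setup_and_get_device above, written as a pure function so it can be checked
# without GPUs or a process group. With CUDA, each process uses the GPU indexed by
# its LOCAL_RANK and the NCCL backend; otherwise it uses the CPU with Gloo.
def _sketch_device_and_backend(cuda_available, env):
  if cuda_available:
    return f"cuda:{env['LOCAL_RANK']}", "nccl"
  return "cpu", "gloo"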


================================================
FILE: common/filesystem/__init__.py
================================================
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs


================================================
FILE: common/filesystem/test_infer_fs.py
================================================
"""Minimal test for infer_fs.

Mostly a test that it returns an object
"""
from tml.common.filesystem import infer_fs


def test_infer_fs():
  local_path = "/tmp/local_path"
  gcs_path = "gs://somebucket/somepath"

  local_fs = infer_fs(local_path)
  gcs_fs = infer_fs(gcs_path)

  # This should return two different objects
  assert local_fs != gcs_fs


================================================
FILE: common/filesystem/util.py
================================================
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs


GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
  if path.startswith("gs://"):
    return GCS_FS
  elif path.startswith("hdfs://"):
    # We can probably use pyarrow HDFS to support this.
    raise NotImplementedError("HDFS not yet supported")
  else:
    return LOCAL_FS


def is_local_fs(fs):
  return fs == LOCAL_FS


def is_gcs_fs(fs):
  return fs == GCS_FS


================================================
FILE: common/log_weights.py
================================================
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel


def weights_to_log(
  model: torch.nn.Module,
  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
  """Creates a dict of reduced weights to log, giving a sense of training progress.

  Args:
    model: model to traverse.
    how_to_log: if a function, then applies this to every parameter, if a dict
      then only applies and logs specified parameters.

  """
  if not how_to_log:
    return

  to_log = dict()
  named_parameters = model.named_parameters()
  logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
  if isinstance(model, DistributedModelParallel):
    named_parameters = itertools.chain(
      named_parameters, model._dmp_wrapped_module.named_parameters()
    )
    logging.info(
      f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
    )
  for param_name, params in named_parameters:
    if callable(how_to_log):
      how = how_to_log
    else:
      how = how_to_log.get(param_name)  # type: ignore[assignment]
    if not how:
      continue  # type: ignore
    to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
  return to_log


def log_ebc_norms(
  model_state_dict,
  ebc_keys: List[str],
  sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
  """Logs the norms of the embedding tables specified by ebc_keys.
  Currently logs the average norm per rank.

  Args:
      model_state_dict: model.state_dict()
      ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
      sample_size: Limits number of rows per rank to compute average on to avoid OOM.
  """
  norm_logs = dict()
  for emb_key in ebc_keys:
    norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
    if emb_key in model_state_dict:
      emb_weight = model_state_dict[emb_key]
      try:
        emb_weight_tensor = emb_weight.local_tensor()
      except AttributeError as e:
        logging.info(e)
        emb_weight_tensor = emb_weight
      logging.info("Running Tensor.detach()")
      emb_weight_tensor = emb_weight_tensor.detach()
      sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
        : min(sample_size, emb_weight_tensor.shape[0])
      ]
      # WARNING: the .cpu() transfer executes a malloc that may be the cause of memory leaks.
      # Reduce sample_size if you observe frequent OOM errors, or remove weight logging.
      norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
      logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
      norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))

    all_norms = [
      torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())
    ]
    dist.all_gather(all_norms, norms)
    for idx, norm in enumerate(all_norms):
      if norm != -1.0:
        norm_logs[f"{emb_key}-norm-{idx}"] = norm
  logging.info(f"Norm Logs are {norm_logs}")
  return norm_logs
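
# Illustrative sketch only (hypothetical helper, pure Python, no torch.distributed):
# the aggregation pattern in log_ebc_norms above. Each rank samples up to
# `sample_size` rows of its local shard and computes their mean L2 norm. A rank
# whose state_dict lacks the table (modeled here as None) contributes -1.0. A real
# mean norm is never negative, so after the all-gather the -1.0 entries can be
# dropped safely, and the rest are keyed by rank index.
def _sketch_gather_norms(emb_key, per_rank_rows, sample_size, seed=0):
  import math
  import random

  rng = random.Random(seed)
  gathered = []  # stands in for the all_gather output, one value per rank
  for rows in per_rank_rows:
    if rows is None:
      gathered.append(-1.0)
      continue
    sample = rng.sample(rows, min(sample_size, len(rows)))
    norms = [math.sqrt(sum(x * x for x in row)) for row in sample]
    gathered.append(sum(norms) / len(norms))
  return {f"{emb_key}-norm-{idx}": n for idx, n in enumerate(gathered) if n != -1.0}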


================================================
FILE: common/modules/embedding/config.py
================================================
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class EmbeddingSnapshot(base_config.BaseConfig):
  """Configuration for Embedding snapshot"""

  emb_name: str = pydantic.Field(
    ..., description="Name of the embedding table from the loaded snapshot"
  )
  embedding_snapshot_uri: str = pydantic.Field(
    ..., description="Path to torchsnapshot of the embedding"
  )


class EmbeddingBagConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBag."""

  name: str = pydantic.Field(..., description="name of embedding bag")
  num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
  embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
  pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
  vocab: str = pydantic.Field(
    None, description="Directory to parquet files of mapping from entity ID to table index."
  )
  # make sure to use an optimizer that matches:
  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
  optimizer: OptimizerConfig
  data_type: DataType


class LargeEmbeddingsConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBagCollection.

  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
  """

  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
  tables_to_log: List[str] = pydantic.Field(
    None, description="list of embedding table names that we want to log during training"
  )


class Mode(str, Enum):
  """Job modes."""

  TRAIN = "train"
  EVALUATE = "evaluate"
  INFERENCE = "inference"


================================================
FILE: common/modules/embedding/embedding.py
================================================
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np


class LargeEmbeddings(nn.Module):
  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if (table.data_type == DataType.FP32)
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("********************** EBC named params are **********")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery
    # on large_embedding models to prep them for serving
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
    return self.surgery_cut_point(pooled_embs)
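`LargeEmbeddings` configures every table with `torchrec.PoolingType.SUM`, so an example's embedding for a feature is the sum of the table rows for all of its IDs. The pure-Python sketch below shows that computation on a jagged input given as `values` plus `lengths`, the layout a single key of a `KeyedJaggedTensor` uses. It is illustrative only; `sum_pool_jagged` is a hypothetical name, and torchrec performs this on device with fused kernels.

```python
from typing import List


def sum_pool_jagged(
  table: List[List[float]], values: List[int], lengths: List[int]
) -> List[List[float]]:
  """Sum-pool embedding rows per example.

  `values` holds the IDs for all examples back to back; `lengths[i]` says how
  many of them belong to example i. An example with no IDs pools to zeros.
  """
  dim = len(table[0])
  pooled: List[List[float]] = []
  offset = 0
  for n in lengths:
    acc = [0.0] * dim
    for idx in values[offset : offset + n]:
      for d, x in enumerate(table[idx]):
        acc[d] += x
    pooled.append(acc)
    offset += n
  return pooled


if __name__ == "__main__":
  table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
  # Example 0 has IDs {0, 2}, example 1 has none, example 2 has {1}.
  print(sum_pool_jagged(table, values=[0, 2, 1], lengths=[2, 0, 1]))
```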


================================================
FILE: common/run_training.py
================================================
import os
import subprocess
import sys
from typing import Optional

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils

import torch
import torch.distributed.run


def is_distributed_worker():
  world_size = os.environ.get("WORLD_SIZE", None)
  rank = os.environ.get("RANK", None)
  return world_size is not None and rank is not None


def maybe_run_training(
  train_fn,
  module_name,
  nproc_per_node: Optional[int] = None,
  num_nodes: Optional[int] = None,
  set_python_path_in_subprocess: bool = False,
  is_chief: Optional[bool] = False,
  **training_kwargs,
):
  """Wrapper function for single- or multi-node, multi-GPU PyTorch training.

  If the necessary distributed PyTorch environment variables
  (WORLD_SIZE, RANK) have been set, then this function executes
  `train_fn(**training_kwargs)`.

  Otherwise, this function calls torchrun and points it at the calling module
  `module_name`. Torchrun sets the necessary environment variables in the processes
  it spawns, which re-enter this function and start training.

  Args:
    train_fn: The function that is responsible for training.
    module_name: The name of the module that this function was called from;
       used as the torchrun entrypoint.
    nproc_per_node: Number of worker processes per node. Defaults to the number of visible
       CUDA devices, or to the chief's accelerator count when CUDA is unavailable.
    num_nodes: Number of nodes; if None, inferred from the environment (1 + number of workers).
    set_python_path_in_subprocess: Whether to launch torchrun as a subprocess with PYTHONPATH
       set to the current sys.path.
    is_chief: Whether the process is running on the chief.
    **training_kwargs: Keyword arguments for `train_fn`; forwarded through torchrun as
       `--key=value` flags.
  """

  machines = utils.machine_from_env()
  if num_nodes is None:
    num_nodes = 1
    if machines.num_workers:
      num_nodes += machines.num_workers

  if is_distributed_worker():
    # world_size, rank, etc are set; assuming any other env vars are set (checks to come)
    # start the actual training!
    train_fn(**training_kwargs)
  else:
    if nproc_per_node is None:
      if torch.cuda.is_available():
        nproc_per_node = torch.cuda.device_count()
      else:
        nproc_per_node = machines.chief.num_accelerators

    # Re-serialize the training kwargs as `--key=value` flags so torchrun passes them
    # through to each spawned worker (a temporary measure).
    args = list(f"--{key}={val}" for key, val in training_kwargs.items())

    cmd = [
      "--nnodes",
      str(num_nodes),
    ]
    if nproc_per_node:
      cmd.extend(["--nproc_per_node", str(nproc_per_node)])
    if num_nodes > 1:
      cluster_resolver = utils.cluster_resolver()
      backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
      cmd.extend(
        [
          "--rdzv_backend",
          "c10d",
          "--rdzv_id",
          backend_address,
        ]
      )
      # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
      if is_chief:
        cmd.extend(["--rdzv_endpoint", "localhost:2222"])
      else:
        cmd.extend(["--rdzv_endpoint", backend_address])
    else:
      cmd.append("--standalone")

    cmd.extend(
      [
        str(module_name),
        *args,
      ]
    )
    logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")

    # Call torchrun on this module;  will spawn new processes and re-run this
    # function, eventually calling "train_fn". The following line sets the PYTHONPATH to accommodate
    # bazel stubbing for the main binary.
    if set_python_path_in_subprocess:
      subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
    else:
      torch.distributed.run.main(cmd)
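The torchrun argument list assembled by `maybe_run_training` can be isolated into a pure function, which makes the flag logic easy to unit test. This is a hypothetical refactoring sketch: it mirrors the branches above (`--standalone` for one node, a c10d rendezvous with a `localhost` endpoint on the chief otherwise) but takes the chief's address as an argument instead of querying the cluster resolver.

```python
from typing import Dict, List, Optional


def build_torchrun_args(
  module_name: str,
  num_nodes: int,
  nproc_per_node: Optional[int],
  training_kwargs: Dict[str, object],
  backend_address: Optional[str] = None,
  is_chief: bool = False,
) -> List[str]:
  """Hypothetical pure version of the torchrun argv built by maybe_run_training."""
  cmd = ["--nnodes", str(num_nodes)]
  if nproc_per_node:
    cmd.extend(["--nproc_per_node", str(nproc_per_node)])
  if num_nodes > 1:
    if backend_address is None:
      raise ValueError("multi-node runs need the chief's address")
    cmd.extend(["--rdzv_backend", "c10d", "--rdzv_id", backend_address])
    # The chief rendezvouses on localhost (see pytorch/pytorch#79388).
    endpoint = "localhost:2222" if is_chief else backend_address
    cmd.extend(["--rdzv_endpoint", endpoint])
  else:
    cmd.append("--standalone")
  # Entrypoint module first, then the re-serialized training kwargs.
  cmd.append(module_name)
  cmd.extend(f"--{key}={val}" for key, val in training_kwargs.items())
  return cmd
```

The returned list corresponds to what `torch.distributed.run.main` receives above, or to what follows `torchrun` in the subprocess path.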


================================================
FILE: common/test_device.py
================================================
"""Minimal test for device.

Mostly a test that this can be imported properly even though it was moved.
"""
from unittest.mock import patch

import tml.common.device as device_utils


def test_device():
  with patch("tml.common.device.dist.init_process_group"):
    device = device_utils.setup_and_get_device(tf_ok=False)
  assert device.type == "cpu"


================================================
FILE: common/testing_utils.py
================================================
from contextlib import contextmanager
import datetime
import os
from unittest.mock import patch

import torch.distributed as dist
from tml.ml_logging.torch_logging import logging


MOCK_ENV = {
  "LOCAL_RANK": "0",
  "WORLD_SIZE": "1",
  "LOCAL_WORLD_SIZE": "1",
  "MASTER_ADDR": "localhost",
  "MASTER_PORT": "29501",
  "RANK": "0",
}


@contextmanager
def mock_pg():
  with patch.dict(os.environ, MOCK_ENV):
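    dist.init_process_group(
      backend="gloo",
      timeout=datetime.timedelta(1),
    )
    # Tear down exactly once, whether or not the body raises. Initializing outside the
    # `try` avoids calling destroy_process_group on a group that was never created.
    try:
      yield
    finally: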
      dist.destroy_process_group()
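`mock_pg` works because `unittest.mock.patch.dict` sets the torchrun-style variables only for the duration of the `with` block and restores the environment afterwards. The sketch below reuses the `is_distributed_worker` check from `common/run_training.py` (copied inline so the snippet runs on its own) to show the variables appearing and then disappearing; `MOCK_ENV` here is a trimmed, illustrative copy and `check_under_mock` is a hypothetical helper.

```python
import os
from unittest.mock import patch

# Trimmed copy of MOCK_ENV: only the two variables the check reads.
MOCK_ENV = {"WORLD_SIZE": "1", "RANK": "0"}


def is_distributed_worker() -> bool:
  # Same logic as common/run_training.py: torchrun sets both variables.
  return os.environ.get("WORLD_SIZE") is not None and os.environ.get("RANK") is not None


def check_under_mock() -> bool:
  with patch.dict(os.environ, MOCK_ENV):
    return is_distributed_worker()


if __name__ == "__main__":
  print("outside:", is_distributed_worker())
  print("inside mock:", check_under_mock())
  print("after mock:", is_distributed_worker())
```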


================================================
FILE: common/utils.py
================================================
import yaml
import getpass
import os
import string
from typing import Tuple, Type, TypeVar

from tml.core.config import base_config

import fsspec

C = TypeVar("C", bound=base_config.BaseConfig)


def _read_file(path):
  # Open in text mode so the contents can be passed to string.Template.
  with fsspec.open(path, "r") as f:
    return f.read()


def setup_configuration(
  config_type: Type[C],
  yaml_path: str,
  substitute_env_variable: bool = False,
) -> C:
  """Resolves a config at a yaml path.

  Args:
    config_type: Pydantic config class to load.
    yaml_path: yaml path of the config file.
    substitute_env_variable: If True, substitute strings of the form $VAR or ${VAR} with the
    value of the corresponding environment variable whenever possible. If an environment
    variable doesn't exist, the string is left unchanged.

  Returns:
    The pydantic config object.
  """

  def _substitute(s):
    if substitute_env_variable:
      return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
    return s

  assert config_type is not None, "can't use all_config without config_type"
  # Substitute on the raw text before parsing, as load_config_from_yaml does.
  content = yaml.safe_load(_substitute(_read_file(yaml_path)))
  return config_type.parse_obj(content)
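Environment substitution in both config loaders is plain `string.Template.safe_substitute` applied to the raw YAML text. The sketch below isolates that step under a hypothetical name, `substitute_env`. Keyword overrides take precedence over the mapping, which is how `USER` is pinned to the current user above; unknown names stay as-is, and a literal dollar sign must be written as `$$`.

```python
import string
from typing import Mapping


def substitute_env(text: str, env: Mapping[str, str], **overrides: str) -> str:
  """Replace $VAR / ${VAR} in raw YAML text; unknown names are left unchanged."""
  # safe_substitute never raises KeyError; keyword overrides win over `env`.
  return string.Template(text).safe_substitute(env, **overrides)


if __name__ == "__main__":
  raw = "save_dir: ${BASE}/$USER/run\nkey_path: $UNSET_VAR/key\n"
  # USER="alice" is a hypothetical value standing in for getpass.getuser().
  print(substitute_env(raw, {"BASE": "/tmp"}, USER="alice"))
```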


================================================
FILE: common/wandb.py
================================================
from typing import Any, Dict, List

import tml.core.config as base_config

import pydantic


class WandbConfig(base_config.BaseConfig):
  host: str = pydantic.Field(
    "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
    description="Host of Weights and Biases instance, passed to login.",
  )
  key_path: str = pydantic.Field(description="Path to key file.")

  name: str = pydantic.Field(None, description="Name of the experiment, passed to init.")
  entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.")
  project: str = pydantic.Field(None, description="Name of wandb project, passed to init.")
  tags: List[str] = pydantic.Field([], description="List of tags, passed to init.")
  notes: str = pydantic.Field(None, description="Notes, passed to init.")
  metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.")


================================================
FILE: core/__init__.py
================================================


================================================
FILE: core/config/__init__.py
================================================
from tml.core.config.base_config import BaseConfig
from tml.core.config.config_load import load_config_from_yaml

# Make mypy happy by explicitly re-exporting the symbols intended for end-user use.
__all__ = ["BaseConfig", "load_config_from_yaml"]


================================================
FILE: core/config/base_config.py
================================================
"""Base class for all config (forbids extra fields)."""

import collections
import functools
import yaml

import pydantic


class BaseConfig(pydantic.BaseModel):
  """Base class for all derived config classes.

  This class provides some convenient functionality:
    - Disallows extra fields when constructing an object, so a misspelled or
      unexpected argument raises an error instead of being silently ignored.
    - "one_of" fields. A subclass can group optional fields and enforce
      that only one of the fields be set. For example:

      ```
      class ExampleConfig(BaseConfig):
        x: int = Field(None, one_of="group_1")
        y: int = Field(None, one_of="group_1")

      ExampleConfig(x=1) # ok
      ExampleConfig(y=1) # ok
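      ExampleConfig(x=1, y=1) # raises pydantic.ValidationError
      ```
    - "at_most_one_of" fields. Like "one_of", except that leaving every field in
      the group unset is also allowed.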
  """

  class Config:
    """Forbids extras."""

    extra = pydantic.Extra.forbid  # noqa

  @classmethod
  @functools.lru_cache()
  def _field_data_map(cls, field_data_name):
    """Map each value of the given field-data key (e.g. "one_of") to the fields that carry it."""
    schema = cls.schema()
    one_of = collections.defaultdict(list)
    for field, fdata in schema["properties"].items():
      if field_data_name in fdata:
        one_of[fdata[field_data_name]].append(field)
    return one_of

  @pydantic.root_validator
  def _one_of_check(cls, values):
    """Validate that exactly one field in each 'one_of' group is set."""
    one_of_map = cls._field_data_map("one_of")
    for one_of, field_names in one_of_map.items():
      if sum([values.get(n, None) is not None for n in field_names]) != 1:
        raise ValueError(f"Exactly one of {','.join(field_names)} required.")
    return values

  @pydantic.root_validator
  def _at_most_one_of_check(cls, values):
    """Validate that at most one field in each 'at_most_one_of' group is set."""
    at_most_one_of_map = cls._field_data_map("at_most_one_of")
    for one_of, field_names in at_most_one_of_map.items():
      if sum([values.get(n, None) is not None for n in field_names]) > 1:
        raise ValueError(f"At most one of {','.join(field_names)} can be set.")
    return values

  def pretty_print(self) -> str:
    """Return a human legible (yaml) representation of the config useful for logging."""
    return yaml.dump(self.dict())
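The two root validators reduce to counting non-`None` values per group. The sketch below restates that logic without pydantic, assuming the group tags have already been read out of the schema (the real class extracts them via `cls.schema()`); `group_fields`, `check_one_of` and `check_at_most_one_of` are hypothetical names.

```python
from collections import defaultdict
from typing import Any, Dict, List, Mapping, Optional


def group_fields(tags: Mapping[str, Optional[str]]) -> Dict[str, List[str]]:
  """Invert {field: group_name or None} into {group_name: [fields]}."""
  groups: Dict[str, List[str]] = defaultdict(list)
  for field, group in tags.items():
    if group is not None:
      groups[group].append(field)
  return dict(groups)


def count_set(values: Mapping[str, Any], names: List[str]) -> int:
  # A field counts as set when its value is not None (same rule as BaseConfig).
  return sum(values.get(name) is not None for name in names)


def check_one_of(values: Mapping[str, Any], groups: Dict[str, List[str]]) -> None:
  for names in groups.values():
    if count_set(values, names) != 1:
      raise ValueError(f"Exactly one of {','.join(names)} required.")


def check_at_most_one_of(values: Mapping[str, Any], groups: Dict[str, List[str]]) -> None:
  for names in groups.values():
    if count_set(values, names) > 1:
      raise ValueError(f"At most one of {','.join(names)} can be set.")
```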


================================================
FILE: core/config/base_config_test.py
================================================
from unittest import TestCase

from tml.core.config import BaseConfig

import pydantic


class BaseConfigTest(TestCase):
  def test_extra_forbidden(self):
    class Config(BaseConfig):
      x: int

    Config(x=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=2)

  def test_one_of(self):
    class Config(BaseConfig):
      x: int = pydantic.Field(None, one_of="f")
      y: int = pydantic.Field(None, one_of="f")

    with self.assertRaises(pydantic.ValidationError):
      Config()
    Config(x=1)
    Config(y=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=3)

  def test_at_most_one_of(self):
    class Config(BaseConfig):
      x: int = pydantic.Field(None, at_most_one_of="f")
      y: str = pydantic.Field(None, at_most_one_of="f")

    Config()
    Config(x=1)
    Config(y="a")
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y="a")


================================================
FILE: core/config/config_load.py
================================================
import yaml
import string
import getpass
import os
from typing import Type

from tml.core.config.base_config import BaseConfig


def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
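  """Recommended method to load a config file (a yaml file) and parse it.

  Because we have a shared filesystem, the recommended route to running jobs is to put modified
  config files with the desired parameters somewhere on the filesystem and run jobs pointing to
  them. Before parsing, $VAR and ${VAR} are substituted from the environment (USER is always set
  to the current user); unknown variables are left unchanged.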
  """

  def _substitute(s):
    return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())

  with open(yaml_path, "r") as f:
    raw_contents = f.read()
    obj = yaml.safe_load(_substitute(raw_contents))

  return config_type.parse_obj(obj)


================================================
FILE: core/config/test_config_load.py
================================================
from unittest import TestCase

from tml.core.config import BaseConfig, load_config_from_yaml

import getpass


class _PointlessConfig(BaseConfig):
  a: int
  user: str


def test_load_config_from_yaml(tmp_path):
  yaml_path = tmp_path.joinpath("test.yaml").as_posix()
  with open(yaml_path, "w") as yaml_file:
    yaml_file.write("""a: 3\nuser: ${USER}\n""")

  pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)

  assert pointless_config.a == 3
  assert pointless_config.user == getpass.getuser()


================================================
FILE: core/config/training.py
================================================
from typing import Any, Dict, List, Optional

from tml.common.wandb import WandbConfig
from tml.core.config import base_config
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig

import pydantic


class RuntimeConfig(base_config.BaseConfig):
  wandb: WandbConfig = pydantic.Field(None)
  enable_tensorfloat32: bool = pydantic.Field(
    False, description="Use tensorfloat32 if on Ampere devices."
  )
  enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.")


class TrainingConfig(base_config.BaseConfig):
  save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
  num_train_steps: pydantic.PositiveInt = 10000
  initial_checkpoint_dir: str = pydantic.Field(
    None, description="Directory of initial checkpoints", at_most_one_of="initialization"
  )
  checkpoint_every_n: pydantic.PositiveInt = 1000
  checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
    None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
  )
  train_log_every_n: pydantic.PositiveInt = 1000
  num_eval_steps: int = pydantic.Field(
    16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
  )
  eval_log_every_n: pydantic.PositiveInt = 5000

  eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60

  gradient_accumulation: int = pydantic.Field(
    None, description="Number of replica steps to accumulate gradients."
  )
  num_epochs: pydantic.PositiveInt = 1


================================================
FILE: core/custom_training_loop.py
================================================
"""Torch and torchrec specific training and evaluation loops.

Features (go/100_enablements):
    - CUDA data-fetch, compute, gradient-push overlap
    - Large learnable embeddings through torchrec
    - On/off-chief evaluation
    - Warmstart/checkpoint management
    - go/dataset-service 0-copy integration

"""
import datetime
import os
from typing import Callable, Dict, Iterable, List, Mapping, Optional


from tml.common import log_weights
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.core.losses import get_global_loss_detached
from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from tml.core.train_pipeline import TrainPipelineSparseDist

import tree
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm


def get_new_iterator(iterable: Iterable):
  """Obtain a new iterator from `iterable`.

  If the iterable uses tf.data.Dataset internally, getting a new iterator every N steps avoids a
  memory leak. For this to work, iter(iterable) must return a "fresh" iterator backed by a new
  tf.data.Iterator instance. In particular, `iterable` can be a torch.utils.data.IterableDataset
  or a torch.utils.data.DataLoader.

  When using DDS, this reset does not change the order in which elements are received (excluding
  elements already prefetched), provided that iter(iterable) internally uses a new instance of
  tf.data.Dataset created by calling from_dataset_id. RecapDataset satisfies this requirement.

  Args:
    iterable: The iterable to obtain a new iterator from.

  Returns:
    A new iterator over `iterable`.
  """
  return iter(iterable)
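`get_new_iterator` only helps if `iter(iterable)` really returns independent state on each call. The class below is a hypothetical, list-backed stand-in for a tf.data-backed dataset that satisfies this contract. Passing a generator object directly would not, because `iter(gen)` returns the same, possibly exhausted, generator.

```python
from typing import Iterator, List


class ReiterableDataset:
  """Hypothetical iterable: each iter() call builds a brand-new generator."""

  def __init__(self, items: List[int]) -> None:
    self._items = items

  def __iter__(self) -> Iterator[int]:
    # A new generator object (and therefore new iteration state) per call.
    return (x for x in self._items)


if __name__ == "__main__":
  ds = ReiterableDataset([1, 2, 3])
  first = iter(ds)
  next(first)
  second = iter(ds)  # unaffected by progress on `first`
  print(list(second), list(first))
```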


def _get_step_fn(pipeline, data_iterator, training: bool):
  def step_fn():
    # It turns out that model.train() and model.eval() simply switch a flag inside the model
    # class, so it's somewhat safer to set the mode here, on every step.
    if training:
      pipeline._model.train()
    else:
      pipeline._model.eval()

    outputs = pipeline.progress(data_iterator)
    return tree.map_structure(lambda elem: elem.detach(), outputs)

  return step_fn


@torch.no_grad()
def _run_evaluation(
  pipeline,
  dataset,
  eval_steps: int,
  metrics: tm.MetricCollection,
  eval_batch_size: int,
  logger=None,
):
  """Runs the evaluation loop over all evaluation iterators."""
  dataset = get_new_iterator(dataset)
  step_fn = _get_step_fn(pipeline, dataset, training=False)
  last_time = datetime.datetime.now()
  logging.info(f"Starting {eval_steps} steps of evaluation.")
  for _ in range(eval_steps):
    outputs = step_fn()
    metrics.update(outputs)
  eval_ex_per_s = (
    eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()
  )
  logging.info(f"eval examples_per_s : {eval_ex_per_s}")
  metrics_result = metrics.compute()
  # Reset at the end to release metric state that is no longer in use, and to prevent
  # accumulation across multiple evaluation splits (which would report a running average).
  metrics.reset()
  return metrics_result


def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  logging_interval: int,
  train_steps: int,
  checkpoint_frequency: int,
  dataset: Iterable,
  worker_batch_size: int,
  num_workers: Optional[int] = 0,
  enable_amp: bool = False,
  initial_checkpoint_dir: Optional[str] = None,
  gradient_accumulation: Optional[int] = None,
  logger_initializer: Optional[Callable] = None,
  scheduler: _LRScheduler = None,
  metrics: Optional[tm.MetricCollection] = None,
  parameters_to_log: Optional[Dict[str, Callable]] = None,
  tables_to_log: Optional[List[str]] = None,
) -> None:
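  """Runs training on the given model with a TrainPipelineSparseDist, checkpointing periodically.

  Args:
    model: Model to train.
    optimizer: Optimizer for the model parameters.
    device: Device to run training on.
    save_dir: Checkpoint directory; training resumes from the latest checkpoint found here.
    logging_interval: Log training stats every this many steps.
    train_steps: Number of per-worker training steps.
    checkpoint_frequency: Save a checkpoint at steps that are multiples of this value; the
                          check only runs on logging steps.
    dataset: Data iterable for the training set.
    worker_batch_size: Batch size per worker, used to compute example throughput.
    num_workers: Number of additional workers; global throughput is estimated as the worker
                 throughput times (1 + num_workers).
    enable_amp: Enable automatic mixed precision.
    initial_checkpoint_dir: Path of the form <base>/<global_step> to warm start from when
                            save_dir has no checkpoint; the step counter still starts at 0.
    gradient_accumulation: Number of steps to accumulate gradients over.
    logger_initializer: Optional callable returning a logger; called with mode="train".
    scheduler: Optional learning rate scheduler, stepped once per training step.
    metrics: Optional metric collection (not used by the training loop itself).
    parameters_to_log: Optional mapping from parameter names to functions used to log them.
    tables_to_log: Optional list of embedding table names whose norms are logged.
  """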

  train_pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    enable_amp=enable_amp,
    grad_accum=gradient_accumulation,
  )  # type: ignore[var-annotated]

  # We explicitly initialize optimizer state here so that checkpoint will work properly.
  if hasattr(train_pipeline._optimizer, "init_state"):
    train_pipeline._optimizer.init_state()

  save_state = {
    "model": train_pipeline._model,
    "optimizer": train_pipeline._optimizer,
    "scaler": train_pipeline._grad_scaler,
  }

  chosen_checkpoint = None
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )

  if save_dir:
    chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)

  start_step = 0
  start_walltime = 0.0
  if chosen_checkpoint:
    # Skip restoration and exit if we should be finished.
    chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
    if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:
      logging.info(
        "Not restoring and finishing training as latest checkpoint "
        f"{chosen_checkpoint} found "
        f"at global_step ({chosen_checkpoint_global_step}) >= "
        f"train_steps ({dist.get_world_size() * train_steps})"
      )
      return
    logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
    checkpoint_handler.restore(chosen_checkpoint)
    start_step = checkpoint_handler.step
    start_walltime = checkpoint_handler.walltime
  elif initial_checkpoint_dir:
    base, ckpt_step = os.path.split(initial_checkpoint_dir)
    warmstart_handler = snapshot_lib.Snapshot(
      save_dir=base,
      state=save_state,
    )
    ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
    logging.info(
      f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0."
    )
    warmstart_handler.restore(ckpt)

  train_logger = logger_initializer(mode="train") if logger_initializer else None
  train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)

  # Count the number of parameters in the model.
  nb_param = 0
  for p in model.parameters():
    nb_param += p.numel()
  logging.info(f"Model has {nb_param} parameters")

  last_time = datetime.datetime.now()
  start_time = last_time
  last_pending_snapshot = None
  for step in range(start_step, train_steps + 1):
    checkpoint_handler.step = step
    outputs = train_step_fn()
    step_done_time = datetime.datetime.now()
    checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime

    if scheduler:
      scheduler.step()

    if step % logging_interval == 0:
      interval_time = (step_done_time - last_time).total_seconds()
      steps_per_s = logging_interval / interval_time
      worker_example_per_s = steps_per_s * worker_batch_size
      global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
      global_step = step

      log_values = {
        "global_step": global_step,
        "loss": get_global_loss_detached(outputs["loss"]),
        "steps_per_s": steps_per_s,
        "global_example_per_s": global_example_per_s,
        "worker_examples_per_s": worker_example_per_s,
        "active_training_walltime": checkpoint_handler.walltime,
      }
      if parameters_to_log:
        log_values.update(
          log_weights.weights_to_log(
            model=model,
            how_to_log=parameters_to_log,
          )
        )
      log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)

      if tables_to_log:
        log_values.update(
          log_weights.log_ebc_norms(
            model_state_dict=train_pipeline._model.state_dict(),
            ebc_keys=tables_to_log,
          )
        )
      if train_logger:
        train_logger.log(log_values, step=global_step)
      log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
      logging.info(f"Step: {step}, training. {log_line}")
      last_time = step_done_time

      # If we just restored, do not save again.
      if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
        if last_pending_snapshot and not last_pending_snapshot.done():
          logging.warning(
            "Begin a new snapshot and the last one hasn't finished. That probably indicates "
            "either you're snapshotting really often or something is wrong. Will now block and "
            "wait for snapshot to finish before beginning the next one."
          )
          last_pending_snapshot.wait()
        last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())

  # Save if we did not just save.
  if checkpoint_frequency and step % checkpoint_frequency != 0:
    # For the final save, we wait (below) for the checkpoint to finish writing so the process
    # doesn't exit before it has completed.
    last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
  logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")

  if last_pending_snapshot:
    logging.info(f"Waiting for any checkpoints to finish.")
    last_pending_snapshot.wait()


def log_eval_results(
  results,
  eval_logger,
  partition_name: str,
  step: int,
):
  results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
  logging.info(f"Step: {step}, evaluation ({partition_name}).")
  for metric_name, metric_value in results.items():
    logging.info(f"\t{metric_name}: {metric_value:1.4e}")

  if eval_logger:
    eval_logger.log(results, step=step, commit=True)


def only_evaluate(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  num_train_steps: int,
  dataset: Iterable,
  eval_batch_size: int,
  num_eval_steps: int,
  eval_timeout_in_s: int,
  eval_logger: Callable,
  partition_name: str,
  metrics: Optional[tm.MetricCollection] = None,
):
  logging.info(f"Evaluating on partition {partition_name}.")
  logging.info("Computing metrics:")
  logging.info(metrics)
  eval_pipeline = TrainPipelineSparseDist(model, optimizer, device)  # type: ignore[var-annotated]
  save_state = {
    "model": eval_pipeline._model,
    "optimizer": eval_pipeline._optimizer,
  }
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )
  for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
    checkpoint_handler.restore(checkpoint_path)
    step = checkpoint_handler.step
    dataset = get_new_iterator(dataset)
    results = _run_evaluation(
      pipeline=eval_pipeline,
      dataset=dataset,
      eval_steps=num_eval_steps,
      eval_batch_size=eval_batch_size,
      metrics=metrics,
    )
    log_eval_results(results, eval_logger, partition_name, step=step)
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
      snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
    if step >= num_train_steps:
      return
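Two pieces of arithmetic in `train` are easy to misread: the logged throughput and the "already finished" check on restore. The helpers below restate both as pure functions. They are a hypothetical sketch for illustration and testing, not functions exported by this module; the names `throughput` and `training_already_finished` are assumptions.

```python
from typing import Dict


def throughput(
  logging_interval: int, interval_seconds: float, worker_batch_size: int, num_workers: int = 0
) -> Dict[str, float]:
  """Mirror of the throughput fields logged by train()."""
  steps_per_s = logging_interval / interval_seconds
  worker_examples_per_s = steps_per_s * worker_batch_size
  # train() assumes every other worker runs at this worker's speed.
  global_example_per_s = worker_examples_per_s * (1 + num_workers)
  return {
    "steps_per_s": steps_per_s,
    "worker_examples_per_s": worker_examples_per_s,
    "global_example_per_s": global_example_per_s,
  }


def training_already_finished(checkpoint_global_step: int, world_size: int, train_steps: int) -> bool:
  """Checkpoints are saved at global_step = step * world_size, so compare on that scale."""
  return checkpoint_global_step >= world_size * train_steps
```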


================================================
FILE: core/debug_training_loop.py
================================================
"""This is a very limited feature training loop useful for interactive debugging.

It is not intended for actual model training (it is not fast and doesn't compile the model).
It does not support checkpointing.

suggested use:

from tml.core import debug_training_loop
debug_training_loop.train(...)
"""

from typing import Iterable, Optional, Dict, Callable, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm

from tml.ml_logging.torch_logging import logging


def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  train_steps: int,
  dataset: Iterable,
  scheduler: _LRScheduler = None,
  # Accept any arguments (to be compatible with the real training loop)
  # but just ignore them.
  *args,
  **kwargs,
) -> None:

  logging.warning("Running debug training loop, don't use for model training.")

  data_iter = iter(dataset)
  for step in range(0, train_steps + 1):
    x = next(data_iter)
    optimizer.zero_grad()
    loss, outputs = model.forward(x)
    loss.backward()
    optimizer.step()

    if scheduler:
      scheduler.step()

    logging.info(f"Step {step} completed. Loss = {loss}")


================================================
FILE: core/loss_type.py
================================================
"""Loss type enums."""
from enum import Enum


class LossType(str, Enum):
  CROSS_ENTROPY = "cross_entropy"
  BCE_WITH_LOGITS = "bce_with_logits"


================================================
FILE: core/losses.py
================================================
"""Loss functions -- including multi task ones."""

import typing

from tml.core.loss_type import LossType
from tml.ml_logging.torch_logging import logging

import torch


def _maybe_warn(reduction: str):
  """
  Warning for reduction different than mean.
  """
  if reduction != "mean":
    logging.warning(
      "For the same global_batch_size, the gradient in DDP is guaranteed to be equal "
      "to the gradient without DDP only for mean reduction. If you need this property for "
      f"the provided reduction {reduction}, it needs to be implemented."
    )


def build_loss(
  loss_type: LossType,
  reduction="mean",
):
  _maybe_warn(reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]

  def loss_fn(logits, labels):
    return f(logits, labels.type_as(logits), reduction=reduction)

  return loss_fn


def get_global_loss_detached(local_loss, reduction="mean"):
  """
  Perform all_reduce to obtain the global loss function using the provided reduction.
  :param local_loss: The local loss of the current rank.
  :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
  :return: The reduced & detached global loss.
  """
  if reduction != "mean":
    logging.warning(
      "The reduction used in this function should be the same as the one used by "
      "the DDP model. By default DDP uses mean, so ensure that DDP is appropriately "
      f"modified for reduction {reduction}."
    )

  if reduction not in ["mean", "sum"]:
    raise ValueError(f"Reduction {reduction} is currently unsupported.")

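  # detach() alone shares storage with local_loss; clone so the in-place division and
  # all_reduce below do not mutate the caller's tensor.
  global_loss = local_loss.detach().clone()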

  if reduction == "mean":
    global_loss.div_(torch.distributed.get_world_size())

  torch.distributed.all_reduce(global_loss)
  return global_loss


def build_multi_task_loss(
  loss_type: LossType,
  tasks: typing.List[str],
  task_loss_reduction="mean",
  global_reduction="mean",
  pos_weights=None,
):
  _maybe_warn(global_reduction)
  _maybe_warn(task_loss_reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]

  loss_reduction_fns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "min": torch.min,
    "max": torch.max,
    "median": torch.median,
  }

  def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
    if pos_weights is None:
      torch_weights = torch.ones([len(tasks)])
    else:
      torch_weights = torch.tensor(pos_weights)

    losses = {}
    for task_idx, task in enumerate(tasks):
      task_logits = logits[:, task_idx]
      label = labels[:, task_idx].type_as(task_logits)

      loss = f(
        task_logits,
        label,
        reduction=task_loss_reduction,
        pos_weight=torch_weights[task_idx],
        weight=weights[:, task_idx],
      )
      losses[f"loss/{task}"] = loss

    losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
    return losses

  return loss_fn


_LOSS_TYPE_TO_FUNCTION = {
  LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits
}
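With `LossType.BCE_WITH_LOGITS`, `build_multi_task_loss` computes PyTorch's weighted binary cross-entropy on logits for each task and then reduces across tasks. The pure-Python sketch below restates that math for the default `"mean"` reductions, using the numerically stable softplus form of the log-sigmoid. It also shows why `get_global_loss_detached` divides by the world size before an all-reduce sum. All function names are hypothetical; this is a reference for checking numbers, not a replacement for the torch implementation.

```python
import math
from typing import Dict, List, Optional


def softplus(x: float) -> float:
  # Numerically stable log(1 + exp(x)).
  return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(x: float, y: float, pos_weight: float = 1.0, weight: float = 1.0) -> float:
  # -w * [p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], using
  # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
  return weight * (pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x))


def multi_task_loss(
  logits: List[List[float]],
  labels: List[List[float]],
  weights: List[List[float]],
  tasks: List[str],
  pos_weights: Optional[List[float]] = None,
) -> Dict[str, float]:
  """Per-task mean loss, then the mean over tasks (both reductions are "mean")."""
  pw = pos_weights or [1.0] * len(tasks)
  losses: Dict[str, float] = {}
  for t, task in enumerate(tasks):
    per_example = [
      bce_with_logits(row[t], lab[t], pw[t], w[t]) for row, lab, w in zip(logits, labels, weights)
    ]
    losses[f"loss/{task}"] = sum(per_example) / len(per_example)
  losses["loss"] = sum(losses.values()) / len(tasks)
  return losses


def simulated_global_mean(local_losses: List[float]) -> float:
  # Each rank divides by the world size, then all_reduce sums: the result is the mean.
  world_size = len(local_losses)
  return sum(loss / world_size for loss in local_losses)
```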


================================================
FILE: core/metric_mixin.py
================================================
"""
Mixin that requires a transform to munge the output dictionary of tensors that a
model produces into the form that torchmetrics.Metric.update expects.

By unifying on our signature for `update`, we can also now use
torchmetrics.MetricCollection which requires all metrics have
the same call signature.

To use, override this with a transform that munges `outputs`
into a kwargs dict that the inherited metric.update accepts.

Here are two examples of how to extend torchmetrics.SumMetric so that it accepts
an output dictionary of tensors and munges it to what SumMetric expects (single `value`)
for its update method.

1. Using as a mixin to inherit from or define a new metric class.

  class Count(MetricMixin, SumMetric):
    def transform(self, outputs):
      return {'value': 1}

2. Redefine an existing metric class.

  SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})

"""
from abc import abstractmethod
from typing import Callable, Dict, List

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]

import torch
import torchmetrics


class MetricMixin:
  @abstractmethod
  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
    ...

  def update(self, outputs: Dict[str, torch.Tensor]):
    results = self.transform(outputs)
    # Do not try to update if any tensor is empty as a result of stratification.
    for value in results.values():
      if torch.is_tensor(value) and not value.nelement():
        return
    super().update(**results)


class TaskMixin:
  def __init__(self, task_idx: int = -1, **kwargs):
    super().__init__(**kwargs)
    self._task_idx = task_idx


class StratifyMixin:
  def __init__(
    self,
    stratifier=None,
    **kwargs,
  ):
    super().__init__(**kwargs)
    self._stratifier = stratifier

  def maybe_apply_stratification(
    self, outputs: Dict[str, torch.Tensor], value_names: List[str]
  ) -> Dict[str, torch.Tensor]:
    """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
    outputs = outputs.copy()
    if not self._stratifier:
      return outputs
    stratifiers = outputs.get("stratifiers")
    if not stratifiers:
      return outputs
    if stratifiers.get(self._stratifier.name) is None:
      return outputs

    mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value)
    target_slice = torch.squeeze(mask.nonzero(), -1)
    for value_name in value_names:
      target = outputs[value_name]
      outputs[value_name] = torch.index_select(target, 0, target_slice)
    return outputs
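The masking in `maybe_apply_stratification` keeps only the rows whose stratifier value matches. The sketch below does the same with lists, assuming plain Python lists instead of tensors. `apply_stratification` is a hypothetical helper that takes the stratifier's `name` and `value` directly instead of a config object.

```python
from typing import Any, Dict, List


def apply_stratification(
  outputs: Dict[str, Any], value_names: List[str], name: str, value: Any
) -> Dict[str, Any]:
  """Hypothetical list-based analogue of StratifyMixin.maybe_apply_stratification."""
  outputs = dict(outputs)  # shallow copy, like outputs.copy()
  stratifiers = outputs.get("stratifiers")
  if not stratifiers or stratifiers.get(name) is None:
    return outputs
  # Equivalent of `mask.nonzero()`: indices whose stratifier equals `value`.
  keep = [i for i, s in enumerate(stratifiers[name]) if s == value]
  for value_name in value_names:
    # Equivalent of torch.index_select(target, 0, target_slice).
    outputs[value_name] = [outputs[value_name][i] for i in keep]
  return outputs


example = {
  "labels": [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]],
  "stratifiers": {"level": [9, 0, 9]},
}
print(apply_stratification(example, ["labels"], "level", 9)["labels"])
# [[0, 1, 2, 3], [2, 3, 4, 0]]
```

With the data from `test_stratified_ctr` (levels `[9, 0, 9]`, value 9), rows 0 and 2 survive, and their column 1 averages to 2.0.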


def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
  """Returns a new class that combines MetricMixin with the given base_metric.

  Functionally the same as using inheritance; it just saves some lines of code
  when no class attributes are needed.

  """

  def transform_method(_self, *args, **kwargs):
    return transform(*args, **kwargs)

  return type(
    base_metric.__name__,
    (
      MetricMixin,
      base_metric,
    ),
    {"transform": transform_method},
  )
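`prepend_transform` builds a class whose MRO puts `MetricMixin` in front of the base metric. As a result, `MetricMixin.update` runs the transform and then forwards the resulting kwargs to the base `update` through `super()`. The sketch below shows that mechanism without torchmetrics. `ToySumMetric` is a hypothetical stand-in for `SumMetric`, and the real mixin's empty-tensor check is omitted.

```python
from abc import abstractmethod
from typing import Any, Callable, Dict


class ToySumMetric:
  """Hypothetical stand-in for torchmetrics.SumMetric: accumulates `value`."""

  def __init__(self) -> None:
    self.total = 0

  def update(self, value: Any) -> None:
    self.total += value

  def compute(self) -> Any:
    return self.total


class MetricMixin:
  @abstractmethod
  def transform(self, outputs: Dict[str, Any]) -> Dict[str, Any]:
    ...

  def update(self, outputs: Dict[str, Any]) -> None:
    # The MRO sends this call on to the base metric's update(**kwargs).
    super().update(**self.transform(outputs))


def prepend_transform(base_metric: type, transform: Callable) -> type:
  def transform_method(_self, *args, **kwargs):
    return transform(*args, **kwargs)

  return type(base_metric.__name__, (MetricMixin, base_metric), {"transform": transform_method})


Count = prepend_transform(ToySumMetric, lambda outputs: {"value": 1})
counter = Count()
for outputs in [{"stuff": 0}] * 3:
  counter.update(outputs)
print(counter.compute())  # 3
print([cls.__name__ for cls in Count.__mro__])
# ['ToySumMetric', 'MetricMixin', 'ToySumMetric', 'object']
```

Because `type()` reuses `base_metric.__name__`, the generated class has the same name as its base class. `test_collections` reads the key back via `__class__.__name__` rather than hard-coding it.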


================================================
FILE: core/metrics.py
================================================
"""Common metrics that also support multi-task models.

We assume multi-task models output predictions and labels of shape
[batch_size, num_tasks], so a task's column is selected with [:, task_idx].

"""
from typing import Any, Dict

from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin

import torch
import torchmetrics as tm


def probs_and_labels(
  outputs: Dict[str, torch.Tensor],
  task_idx: int,
) -> Dict[str, torch.Tensor]:
  preds = outputs["probabilities"]
  target = outputs["labels"]
  if task_idx >= 0:
    preds = preds[:, task_idx]
    target = target[:, task_idx]
  return {
    "preds": preds,
    "target": target.int(),
  }


class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
    value = outputs["probabilities"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
  """

  def __init__(self, num_samples, **kwargs):
    super().__init__(**kwargs)
    self.num_samples = num_samples

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    pos_scores = scores[labels == 1]
    neg_scores = scores[labels == 0]
    result = {
      "value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))]
      > neg_scores[torch.randint(len(neg_scores), (self.num_samples,))]
    }
    return result
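`Auc.transform` estimates AUC by sampling. It draws `num_samples` positives and `num_samples` negatives with replacement and records whether each positive scores strictly higher than its paired negative. The sketch below is a dependency-free version of the same estimator. It assumes lists instead of tensors and uses a seeded `random.Random` in place of `torch.randint`. `sampled_auc` is a hypothetical name.

```python
import random
from typing import Sequence


def sampled_auc(scores: Sequence[float], labels: Sequence[int], num_samples: int, seed: int = 0) -> float:
  """Monte Carlo AUC: P(score of a random positive > score of a random negative)."""
  rng = random.Random(seed)
  pos = [s for s, y in zip(scores, labels) if y == 1]
  neg = [s for s, y in zip(scores, labels) if y == 0]
  # Sample with replacement, like torch.randint in Auc.transform; ties count
  # as misses because the comparison is a strict `>`.
  hits = sum(rng.choice(pos) > rng.choice(neg) for _ in range(num_samples))
  return hits / num_samples


print(sampled_auc([-1.0, -1.0, 1.0, 1.0, 1.0], [0, 0, 1, 1, 1], 1000))  # 1.0
```

Like the metric, this fails when the batch has no positives or no negatives. Here the failure is an `IndexError` from `rng.choice` on an empty list.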


class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The ranks of all positives
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": pos_ranks}
    return result


class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The reciprocal ranks of all positives
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)}
    return result


class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The fraction of positives that rank in the top K when the whole batch is sorted by score.
  Since this is normalized by the number of positives, it is closer to recall@k than precision@k.
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
  """

  def __init__(self, k: int, **kwargs):
    super().__init__(**kwargs)
    self.k = k

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": (pos_ranks <= self.k).float()}
    return result
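`PosRanks`, `ReciprocalRank`, and `HitAtK` all start from the 1-based ranks of the positives after sorting scores in descending order. The sketch below computes the three quantities in plain Python, assuming lists instead of tensors; the helper names are hypothetical. Python's `sorted` is stable, so ties may be ordered differently than with `torch.sort`.

```python
from statistics import mean
from typing import List, Sequence


def positive_ranks(scores: Sequence[float], labels: Sequence[int]) -> List[int]:
  # Indices ordered by descending score; ranks start from 1.
  order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
  return [rank for rank, i in enumerate(order, start=1) if labels[i] == 1]


def mean_rank(scores: Sequence[float], labels: Sequence[int]) -> float:
  return float(mean(positive_ranks(scores, labels)))


def reciprocal_rank(scores: Sequence[float], labels: Sequence[int]) -> float:
  return float(mean(1 / r for r in positive_ranks(scores, labels)))


def hit_at_k(scores: Sequence[float], labels: Sequence[int], k: int) -> float:
  return float(mean(float(r <= k) for r in positive_ranks(scores, labels)))


print(positive_ranks([-1.0, -1.0, 0.5, 1.0, 1.5], [0, 0, 1, 1, 1]))  # [1, 2, 3]
```

On the inputs from `test_pos_rank`, the correct predictions give ranks `[1, 2, 3]` (mean 2.0, reciprocal rank about 0.611), and the bad ones give `[3, 4, 5]` (mean 4.0, reciprocal rank about 0.261).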


================================================
FILE: core/test_metrics.py
================================================
from dataclasses import dataclass

from tml.core import metrics as core_metrics
from tml.core.metric_mixin import MetricMixin, prepend_transform

import torch
from torchmetrics import MaxMetric, MetricCollection, SumMetric


@dataclass
class MockStratifierConfig:
  name: str
  index: int
  value: int


class Count(MetricMixin, SumMetric):
  def transform(self, outputs):
    return {"value": 1}


Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})


def test_count_metric():
  num_examples = 123
  examples = [
    {"stuff": 0},
  ] * num_examples

  metric = Count()
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute().item() == num_examples


def test_collections():
  max_metric = Max()
  count_metric = Count()
  metric = MetricCollection([max_metric, count_metric])

  examples = [{"value": idx} for idx in range(123)]
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute() == {
    max_metric.__class__.__name__: len(examples) - 1,
    count_metric.__class__.__name__: len(examples),
  }


def test_task_dependent_ctr():
  num_examples = 144
  batch_size = 1024
  outputs = [
    {
      "stuff": 0,
      "labels": torch.arange(0, 6).repeat(batch_size, 1),
    }
    for idx in range(num_examples)
  ]

  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=task_idx)
    for output in outputs:
      metric.update(output)
    assert metric.compute().item() == task_idx


def test_stratified_ctr():
  outputs = [
    {
      "stuff": 0,
      # [bsz, tasks]
      "labels": torch.tensor(
        [
          [0, 1, 2, 3],
          [1, 2, 3, 4],
          [2, 3, 4, 0],
        ]
      ),
      "stratifiers": {
        # [bsz]
        "level": torch.tensor(
          [9, 0, 9],
        ),
      },
    }
  ]

  stratifier = MockStratifierConfig(name="level", index=2, value=9)
  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
    for output in outputs:
      metric.update(output)
    # From the dataset of:
    # [
    #   [0, 1, 2, 3],
    #   [1, 2, 3, 4],
    #   [2, 3, 4, 0],
    # ]
    # we pick out
    # [
    #   [0, 1, 2, 3],
    #   [2, 3, 4, 0],
    # ]
    # and with Ctr task_idx, we pick out
    # [
    #   [1,],
    #   [3,],
    # ]
    assert metric.compute().item() == (1 + 3) / 2


def test_auc():
  num_samples = 10000
  metric = core_metrics.Auc(num_samples)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert metric.compute().item() == 1.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 0.0


def test_pos_rank():
  metric = core_metrics.PosRanks()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert metric.compute().item() == 2.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 4.0


def test_reciprocal_rank():
  metric = core_metrics.ReciprocalRank()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert abs(metric.compute().item() - 0.6111) < 0.001

  metric.reset()
  metric.update(outputs_bad)
  assert abs(metric.compute().item() - 0.2611) < 0.001


def test_hit_k():
  hit1_metric = core_metrics.HitAtK(1)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  hit1_metric.update(outputs_correct)
  assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001

  hit1_metric.reset()
  hit1_metric.update(outputs_bad)

  assert hit1_metric.compute().item() == 0

  hit3_metric = core_metrics.HitAtK(3)
  hit3_metric.update(outputs_correct)
  assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001

  hit3_metric.reset()
  hit3_metric.update(outputs_bad)
  assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001


================================================
FILE: core/test_train_pipeline.py
================================================
from dataclasses import dataclass
from typing import Tuple

from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline

import torch
from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
  continuous_features: torch.Tensor
  labels: torch.Tensor


class MockModule(torch.nn.Module):
  def __init__(self) -> None:
    super().__init__()
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    pred = self.model(batch.continuous_features)
    loss = self.loss_fn(pred, batch.labels)
    return (loss, pred)


def create_batch(bsz: int):
  return MockDataclassBatch(
    continuous_features=torch.rand(bsz, 10).float(),
    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
  )


def test_sparse_pipeline():
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      grad_accum=2,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach().numpy() for elem in results]
  # Check that gradients are accumulated, i.e. results do not change between steps 0 and 1, 2 and 3, etc.
  for first, second in zip(results[::2], results[1::2]):
    assert first == second, results

  # Check that we do apply the gradients, i.e. results change between steps 1 and 2, 3 and 4, etc.
  for first, second in zip(results[1::2], results[2::2]):
    assert first != second, results
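The test above relies on the optimizer stepping only on every second `progress` call. The `should_step_optimizer` condition from `TrainPipelineSparseDist.progress` can be tabulated in plain Python. This assumes `_progress_calls` is incremented once per call; that increment is outside this excerpt. `optimizer_step_calls` is a hypothetical helper.

```python
from typing import List, Optional


def optimizer_step_calls(num_calls: int, grad_accum: Optional[int]) -> List[int]:
  """Indices of progress() calls on which the optimizer would step."""
  steps = []
  for progress_calls in range(num_calls):
    # Same boolean expression as `should_step_optimizer` in progress().
    should_step = (
      grad_accum is not None and progress_calls > 0 and (progress_calls + 1) % grad_accum == 0
    ) or grad_accum is None
    if should_step:
      steps.append(progress_calls)
  return steps


print(optimizer_step_calls(8, 2))  # [1, 3, 5, 7]
print(optimizer_step_calls(8, 3))  # [2, 5]
print(optimizer_step_calls(3, None))  # [0, 1, 2]
```

Under this expression, `grad_accum=1` does not step on the very first call (`optimizer_step_calls(4, 1)` returns `[1, 2, 3]`). In contrast, `grad_accum=None` steps on every call.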


def test_amp():
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      enable_amp=True,
      # Not supported on CPU.
      enable_grad_scaling=False,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach() for elem in results]
  for value in results:
    assert value.dtype == torch.bfloat16


================================================
FILE: core/train_pipeline.py
================================================
"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.

"""
import abc
from dataclasses import dataclass, field
import logging
from typing import (
  Any,
  cast,
  Dict,
  Generic,
  Iterator,
  List,
  Optional,
  Set,
  Tuple,
  TypeVar,
)

import torch
from torch.autograd.profiler import record_function
from torch.fx.node import Node
from torchrec.distributed.model_parallel import (
  DistributedModelParallel,
  ShardedModule,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.streamable import Multistreamable, Pipelineable


logger: logging.Logger = logging.getLogger(__name__)


In = TypeVar("In", bound=Pipelineable)
Out = TypeVar("Out")


class TrainPipeline(abc.ABC, Generic[In, Out]):
  @abc.abstractmethod
  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    pass


def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
  assert isinstance(
    batch, (torch.Tensor, Pipelineable)
  ), f"{type(batch)} must implement Pipelineable interface"
  return cast(In, batch.to(device=device, non_blocking=non_blocking))


def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
  if stream is None:
    return
  torch.cuda.current_stream().wait_stream(stream)
  # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
  # PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
  # freed, its memory is likely to be reused by newly constructed tensors. By default,
  # this allocator tracks whether a tensor is still in use only on the CUDA stream where it
  # was created. When a tensor is used by additional CUDA streams, we need to call record_stream
  # to tell the allocator about all these streams. Otherwise, the allocator might free the
  # underlying memory of the tensor once it is no longer used by the creator stream. This is
  # a notable programming trick when we write programs using multiple CUDA streams.
  cur_stream = torch.cuda.current_stream()
  assert isinstance(
    batch, (torch.Tensor, Multistreamable)
  ), f"{type(batch)} must implement Multistreamable interface"
  batch.record_stream(cur_stream)


class TrainPipelineBase(TrainPipeline[In, Out]):
  """
  This class runs training iterations using a pipeline of two stages, each as a CUDA
  stream, namely, the current (default) stream and `self._memcpy_stream`. For each
  iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
  memory, and the default stream runs forward, backward, and optimization.
  """

  def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
  ) -> None:
    self._model = model
    self._optimizer = optimizer
    self._device = device
    self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (
      torch.cuda.Stream() if device.type == "cuda" else None
    )
    self._cur_batch: Optional[In] = None
    self._connected = False

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    cur_batch = next(dataloader_iter)
    self._cur_batch = cur_batch
    with torch.cuda.stream(self._memcpy_stream):
      self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
    self._connected = True

  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    if not self._connected:
      self._connect(dataloader_iter)

    # Fetch next batch
    with record_function("## next_batch ##"):
      next_batch = next(dataloader_iter)
    cur_batch = self._cur_batch
    assert cur_batch is not None

    if self._model.training:
      with record_function("## zero_grad ##"):
        self._optimizer.zero_grad()

    with record_function("## wait_for_batch ##"):
      _wait_for_batch(cur_batch, self._memcpy_stream)

    with record_function("## forward ##"):
      (losses, output) = self._model(cur_batch)

    if self._model.training:
      with record_function("## backward ##"):
        torch.sum(losses, dim=0).backward()

    # Copy the next batch to GPU
    self._cur_batch = cur_batch = next_batch
    with record_function("## copy_batch_to_gpu ##"):
      with torch.cuda.stream(self._memcpy_stream):
        self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)

    # Update
    if self._model.training:
      with record_function("## optimizer ##"):
        self._optimizer.step()

    return output
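`TrainPipelineBase.progress` fetches the next batch before running forward/backward on the current one, so it always holds one batch in flight. The sketch below replays that order with integers standing in for batches and without any CUDA streams. `simulate_base_pipeline` is a hypothetical helper.

```python
from typing import Iterator, List, Tuple


def simulate_base_pipeline(batches: Iterator[int], num_calls: int) -> List[Tuple[int, int]]:
  """Hypothetical trace of TrainPipelineBase's batch order, with ints as batches."""
  cur = next(batches)  # _connect(): the first batch is fetched and staged
  trace = []
  for _ in range(num_calls):
    nxt = next(batches)  # "## next_batch ##" happens before forward
    trace.append((cur, nxt))  # forward/backward on `cur`; `nxt` is then copied to the device
    cur = nxt
  return trace


print(simulate_base_pipeline(iter(range(10)), 3))  # [(0, 1), (1, 2), (2, 3)]
```

The iterator therefore has to yield `num_calls + 1` batches. With only three batches, the third call raises `StopIteration`. For comparison, the sparse pipeline's `_connect` already takes two batches, and its tests build `steps + 2` of them.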


class Tracer(torch.fx.Tracer):
  # Disable proxying buffers during tracing. Ideally, proxying buffers would
  # be enabled, but some models are currently mutating buffer values, which
  # causes errors during tracing. If those models can be rewritten to not do
  # that, we can likely remove this line.
  proxy_buffer_attributes = False

  def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
    super().__init__()
    self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []

  def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
      return True
    return super().is_leaf_module(m, module_qualified_name)


@dataclass
class TrainPipelineContext:
  # pyre-ignore [4]
  input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
  module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
  # pyre-ignore [4]
  feature_processor_forwards: List[Any] = field(default_factory=list)


@dataclass
class ArgInfo:
  # attributes of input batch, e.g. batch.attr1.attr2 call
  # will produce ["attr1", "attr2"]
  input_attrs: List[str]
  # batch[attr1].attr2 will produce [True, False]
  is_getitems: List[bool]
  # name for kwarg of pipelined forward() call or None
  # for a positional arg
  name: Optional[str]


class PipelinedForward:
  def __init__(
    self,
    name: str,
    args: List[ArgInfo],
    module: ShardedModule,
    context: TrainPipelineContext,
    dist_stream: Optional[torch.cuda.streams.Stream],
  ) -> None:
    self._name = name
    self._args = args
    self._module = module
    self._context = context
    self._dist_stream = dist_stream

  # pyre-ignore [2, 24]
  def __call__(self, *input, **kwargs) -> Awaitable:
    assert self._name in self._context.input_dist_requests
    request = self._context.input_dist_requests[self._name]
    assert isinstance(request, Awaitable)
    with record_function("## wait_sparse_data_dist ##"):
      # Finish waiting on the dist_stream,
      # in case some delayed stream scheduling happens during the wait() call.
      with torch.cuda.stream(self._dist_stream):
        data = request.wait()

    # Make sure that both result of input_dist and context
    # are properly transferred to the current stream.
    if self._dist_stream is not None:
      torch.cuda.current_stream().wait_stream(self._dist_stream)
      cur_stream = torch.cuda.current_stream()

      assert isinstance(
        data, (torch.Tensor, Multistreamable)
      ), f"{type(data)} must implement Multistreamable interface"
      # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
      data.record_stream(cur_stream)

      ctx = self._context.module_contexts[self._name]
      ctx.record_stream(cur_stream)

    if len(self._context.feature_processor_forwards) > 0:
      with record_function("## feature_processor ##"):
        for sparse_feature in data:
          if sparse_feature.id_score_list_features is not None:
            for fp_forward in self._context.feature_processor_forwards:
              sparse_feature.id_score_list_features = fp_forward(
                sparse_feature.id_score_list_features
              )

    return self._module.compute_and_output_dist(self._context.module_contexts[self._name], data)

  @property
  def name(self) -> str:
    return self._name

  @property
  def args(self) -> List[ArgInfo]:
    return self._args


def _start_data_dist(
  pipelined_modules: List[ShardedModule],
  batch: In,
  context: TrainPipelineContext,
) -> None:
  context.input_dist_requests.clear()
  context.module_contexts.clear()
  for module in pipelined_modules:
    forward = module.forward
    assert isinstance(forward, PipelinedForward)

    # Retrieve the arguments for the input_dist of the EBC.
    # is_getitem=True means the argument is retrieved by indexing (arg[attr]);
    # False means it is retrieved with getattr (arg.attr).
    # This info was computed in _rewrite_model by tracing the
    # entire model to get the arg_info_list.
    args = []
    kwargs = {}
    for arg_info in forward.args:
      if arg_info.input_attrs:
        arg = batch
        for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
          if is_getitem:
            arg = arg[attr]
          else:
            arg = getattr(arg, attr)
        if arg_info.name:
          kwargs[arg_info.name] = arg
        else:
          args.append(arg)
      else:
        args.append(None)
    # Start input distribution.
    module_ctx = module.create_context()
    context.module_contexts[forward.name] = module_ctx
    context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)

  # Call wait on the first awaitable in the input dist for the tensor splits
  for key, awaitable in context.input_dist_requests.items():
    context.input_dist_requests[key] = awaitable.wait()
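`_start_data_dist` rebuilds each pipelined module's input by replaying the recorded `getattr`/`getitem` path from the batch. The sketch below performs that replay in pure Python, assuming a hypothetical `ArgPath` dataclass in place of `ArgInfo` and a `SimpleNamespace` batch.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class ArgPath:
  """Hypothetical stand-in for ArgInfo."""

  input_attrs: List[Any] = field(default_factory=list)
  is_getitems: List[bool] = field(default_factory=list)
  name: Optional[str] = None


def build_call_args(batch: Any, infos: List[ArgPath]) -> Tuple[List[Any], Dict[str, Any]]:
  args: List[Any] = []
  kwargs: Dict[str, Any] = {}
  for info in infos:
    if not info.input_attrs:
      args.append(None)  # no path was recorded for this argument
      continue
    arg = batch
    for attr, is_getitem in zip(info.input_attrs, info.is_getitems):
      arg = arg[attr] if is_getitem else getattr(arg, attr)
    if info.name:
      kwargs[info.name] = arg  # keyword argument of the pipelined forward()
    else:
      args.append(arg)  # positional argument
  return args, kwargs


batch = SimpleNamespace(sparse={"user_ids": [3, 7]}, dense=[0.5, 1.5])
infos = [
  ArgPath(["sparse", "user_ids"], [False, True]),  # batch.sparse["user_ids"]
  ArgPath(["dense"], [False], name="x"),  # batch.dense, passed as x=...
  ArgPath(),  # nothing recorded
]
print(build_call_args(batch, infos))  # ([[3, 7], None], {'x': [0.5, 1.5]})
```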


def _get_node_args_helper(
  # pyre-ignore
  arguments,
  num_found: int,
  feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
  """
  Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
  It also counts the number of (args + kwargs) found.
  """

  arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
  for arg, arg_info in zip(arguments, arg_info_list):
    if arg is None:
      num_found += 1
      continue
    while True:
      if not isinstance(arg, torch.fx.Node):
        break
      child_node = arg

      if child_node.op == "placeholder":
        num_found += 1
        break
      # Skip this feature processor (fp) node and continue with its input.
      elif feature_processor_arguments is not None and child_node in feature_processor_arguments:
        arg = child_node.args[0]
      elif (
        child_node.op == "call_function"
        and child_node.target.__module__ == "builtins"
        # pyre-ignore[16]
        and child_node.target.__name__ == "getattr"
      ):
        arg_info.input_attrs.insert(0, child_node.args[1])
        arg_info.is_getitems.insert(0, False)
        arg = child_node.args[0]
      elif (
        child_node.op == "call_function"
        and child_node.target.__module__ == "_operator"
        # pyre-ignore[16]
        and child_node.target.__name__ == "getitem"
      ):
        arg_info.input_attrs.insert(0, child_node.args[1])
        arg_info.is_getitems.insert(0, True)
        arg = child_node.args[0]
      else:
        break
  return arg_info_list, num_found


def _get_node_args(
  node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
  num_found = 0
  pos_arg_info_list, num_found = _get_node_args_helper(
    node.args, num_found, feature_processor_nodes
  )
  kwargs_arg_info_list, num_found = _get_node_args_helper(node.kwargs.values(), num_found)

  # Replace with proper names for kwargs
  for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list):
    arg_info_list.name = name

  arg_info_list = pos_arg_info_list + kwargs_arg_info_list
  return arg_info_list, num_found


def _get_unsharded_module_names_helper(
  model: torch.nn.Module,
  path: str,
  unsharded_module_names: Set[str],
) -> bool:
  sharded_children = set()
  for name, child in model.named_children():
    curr_path = path + name
    if isinstance(child, ShardedModule):
      sharded_children.add(name)
    else:
      child_sharded = _get_unsharded_module_names_helper(
        child,
        curr_path + ".",
        unsharded_module_names,
      )
      if child_sharded:
        sharded_children.add(name)

  if len(sharded_children) > 0:
    for name, _ in model.named_children():
      if name not in sharded_children:
        unsharded_module_names.add(path + name)

  return len(sharded_children) > 0


def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
  """
  Returns a list of top-level modules that do not contain any sharded submodules.
  """

  unsharded_module_names: Set[str] = set()
  _get_unsharded_module_names_helper(
    model,
    "",
    unsharded_module_names,
  )
  return list(unsharded_module_names)
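`_get_unsharded_module_names_helper` marks every child that is, or contains, a sharded module. It then records the remaining siblings as names the tracer should treat as leaves. The sketch below applies the same recursion to a hypothetical dict-based module tree, where the string `"sharded"` stands in for a `ShardedModule`.

```python
from typing import Dict, Set, Union

# Hypothetical toy module tree: a dict is a container module; a string leaf is
# either "sharded" (think ShardedModule) or anything else for a dense module.
Tree = Union[str, Dict[str, "Tree"]]


def collect_unsharded(tree: Dict[str, Tree], path: str, out: Set[str]) -> bool:
  sharded_children = set()
  for name, child in tree.items():
    if child == "sharded":
      sharded_children.add(name)
    elif isinstance(child, dict) and collect_unsharded(child, path + name + ".", out):
      sharded_children.add(name)  # the child contains a sharded module somewhere below
  if sharded_children:
    # Siblings on a path to a sharded module hold no sharded modules: record them.
    for name in tree:
      if name not in sharded_children:
        out.add(path + name)
  return bool(sharded_children)


model = {
  "sparse_arch": {"ebc": "sharded", "proj": "dense"},
  "over_arch": {"mlp": "dense"},
  "loss": "dense",
}
leaves: Set[str] = set()
collect_unsharded(model, "", leaves)
print(sorted(leaves))  # ['loss', 'over_arch', 'sparse_arch.proj']
```

Only siblings along a path to a sharded module are recorded. For example, `over_arch.mlp` is not listed separately because `over_arch` as a whole is already a leaf.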


def _rewrite_model(  # noqa C901
  model: torch.nn.Module,
  context: TrainPipelineContext,
  dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:

  # Get underlying nn.Module
  if isinstance(model, DistributedModelParallel):
    model = model.module

  # Collect a list of sharded modules.
  sharded_modules = {}
  fp_modules = {}
  for name, m in model.named_modules():
    if isinstance(m, ShardedModule):
      sharded_modules[name] = m
    if isinstance(m, BaseGroupedFeatureProcessor):
      fp_modules[name] = m

  # Trace a model.
  tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
  graph = tracer.trace(model)

  feature_processor_nodes = []
  # find the fp node
  for node in graph.nodes:
    if node.op == "call_module" and node.target in fp_modules:
      feature_processor_nodes.append(node)
  # Select sharded modules that are top-level in the forward call graph,
  # i.e. which don't have input transformations and whose inputs are
  # derived from the batch only via 'builtins.getattr' or '_operator.getitem'.
  ret = []
  for node in graph.nodes:
    if node.op == "call_module" and node.target in sharded_modules:
      total_num_args = len(node.args) + len(node.kwargs)
      if total_num_args == 0:
        continue
      arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
      if num_found == total_num_args:
        logger.info(f"Module '{node.target}' will be pipelined")
        child = sharded_modules[node.target]
        child.forward = PipelinedForward(
          node.target,
          arg_info_list,
          child,
          context,
          dist_stream,
        )
        ret.append(child)
  return ret


class TrainPipelineSparseDist(TrainPipeline[In, Out]):
  """
  This pipeline overlaps device transfer and `ShardedModule.input_dist()` with
  forward and backward. This helps hide the all2all latency while preserving the
  training forward / backward ordering.

  stage 3: forward, backward - uses default CUDA stream
  stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
  stage 1: device transfer - uses memcpy CUDA stream

  `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
  To be considered a top-level module, a module can only depend on 'getattr' calls on
  input.

  Input model must be symbolically traceable with the exception of `ShardedModule` and
  `DistributedDataParallel` modules.
  """

  synced_pipeline_id: Dict[int, int] = {}

  def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    enable_amp: bool = False,
    enable_grad_scaling: bool = True,
    grad_accum: Optional[int] = None,
  ) -> None:
    self._model = model
    self._optimizer = optimizer
    self._device = device
    self._enable_amp = enable_amp
    # NOTE: Pending upstream feedback. There are two flags because AMP can run without CUDA, but gradient scaling cannot.
    # Background on gradient/loss scaling
    # https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
    # https://pytorch.org/docs/stable/amp.html#gradient-scaling
    self._enable_grad_scaling = enable_grad_scaling
    self._grad_scaler = torch.cuda.amp.GradScaler(
      enabled=self._enable_amp and self._enable_grad_scaling
    )
    logging.info(f"Amp is enabled: {self._enable_amp}")

    # use two data streams to support two concurrent batches
    if device.type == "cuda":
      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
    else:
      if self._enable_amp:
        logging.warning("Amp is enabled, but no CUDA available")
      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None
    self._batch_i: Optional[In] = None
    self._batch_ip1: Optional[In] = None
    self._batch_ip2: Optional[In] = None
    self._connected = False
    self._context = TrainPipelineContext()
    self._pipelined_modules: List[ShardedModule] = []

    self._progress_calls = 0
    if grad_accum is not None:
      assert isinstance(grad_accum, int) and grad_accum > 0
    self._grad_accum = grad_accum

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    # batch 1
    with torch.cuda.stream(self._memcpy_stream):
      batch_i = next(dataloader_iter)
      self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
      # Try to pipeline input data dist.
      self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)

    with torch.cuda.stream(self._data_dist_stream):
      _wait_for_batch(batch_i, self._memcpy_stream)
      _start_data_dist(self._pipelined_modules, batch_i, self._context)

    # batch 2
    with torch.cuda.stream(self._memcpy_stream):
      batch_ip1 = next(dataloader_iter)
      self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)
    self._connected = True
    self.__class__.synced_pipeline_id[id(self._model)] = id(self)

  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    NOTE: This method has been updated to perform gradient accumulation.
    If `_grad_accum` is set, loss values are divided by `_grad_accum`, and the optimizer
    update and gradient reset happen only once every `_grad_accum` training calls of
    `progress` (i.e. training steps) instead of on every call.
    """
    # Zero gradients on the first call of each group of `_grad_accum` calls and step the
    # optimizer on the last call, so every update accumulates exactly `_grad_accum` batches.
    should_step_optimizer = (
      self._grad_accum is None or (self._progress_calls + 1) % self._grad_accum == 0
    )
    should_reset_optimizer = (
      self._grad_accum is None or self._progress_calls % self._grad_accum == 0
    )

    if not self._connected:
      self._connect(dataloader_iter)
    elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
      self._sync_pipeline()
      self.__class__.synced_pipeline_id[id(self._model)] = id(self)

    if self._model.training and should_reset_optimizer:
      with record_function("## zero_grad ##"):
        self._optimizer.zero_grad()

    with record_function("## copy_batch_to_gpu ##"):
      with torch.cuda.stream(self._memcpy_stream):
        batch_ip2 = next(dataloader_iter)
        self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)
    batch_i = cast(In, self._batch_i)
    batch_ip1 = cast(In, self._batch_ip1)

    with record_function("## wait_for_batch ##"):
      _wait_for_batch(batch_i, self._data_dist_stream)

    # Forward
    with record_function("## forward ##"):
      # If using multiple streams (i.e. CUDA), record an event on the default stream
      # before starting the forward pass.
      if self._data_dist_stream:
        event = torch.cuda.current_stream().record_event()
      if self._enable_amp:
        # conditionally apply the model to the batch in the autocast context
        # it appears that `enabled=self._enable_amp` should handle this,
        # but it does not.
        with torch.autocast(
          device_type=self._device.type,
          dtype=torch.bfloat16,
          enabled=self._enable_amp,
        ):
          (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
      else:
        (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

    # Data Distribution
    with record_function("## sparse_data_dist ##"):
      with torch.cuda.stream(self._data_dist_stream):
        _wait_for_batch(batch_ip1, self._memcpy_stream)
        # Make the data dist stream wait for the event recorded on the default
        # stream before starting data dist.
        if self._data_dist_stream:
          # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
          self._data_dist_stream.wait_event(event)
        _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

    if self._model.training:
      # Backward
      with record_function("## backward ##"):
        # Loss is normalized by the number of accumulation steps.
        # The reported loss in `output['loss']` remains the unnormalized value.
        if self._grad_accum is not None:
          losses = losses / self._grad_accum
        self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()

      if should_step_optimizer:
        # Update
        with record_function("## optimizer ##"):
          self._grad_scaler.step(self._optimizer)
          self._grad_scaler.update()

    self._batch_i = batch_ip1
    self._batch_ip1 = batch_ip2

    if self._model.training:
      self._progress_calls += 1

    return output

  def _sync_pipeline(self) -> None:
    """
    Syncs `PipelinedForward` for sharded modules with context and dist stream of the
    current train pipeline. Used when switching between train pipelines for the same
    model.
    """
    for module in self._pipelined_modules:
      module.forward._context = self._context
      module.forward._dist_stream = self._data_dist_stream
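
The three-stage overlap described in the `TrainPipelineSparseDist` docstring can be modeled without CUDA. The sketch below is a hypothetical toy: `ToyPipeline` is not part of the repo. It only tracks which batch index occupies each stage on every `progress()` call, and it does no device copies, streams, or `input_dist`.

```python
from typing import Dict, Iterator, Optional


class ToyPipeline:
  """Hypothetical, CUDA-free model of the batches kept in flight by progress()."""

  def __init__(self) -> None:
    self.batch_i: Optional[int] = None
    self.batch_ip1: Optional[int] = None
    self.connected = False

  def _connect(self, it: Iterator[int]) -> None:
    # Batch i is copied and its input_dist is started before the first step.
    self.batch_i = next(it)
    # Batch i+1 is only copied to the device.
    self.batch_ip1 = next(it)
    self.connected = True

  def progress(self, it: Iterator[int]) -> Dict[str, Optional[int]]:
    if not self.connected:
      self._connect(it)
    # Stage 1 (memcpy stream): fetch and copy batch i+2.
    batch_ip2 = next(it)
    stages = {
      "copy_to_device": batch_ip2,
      # Stage 3 (default stream): forward/backward on batch i.
      "forward_backward": self.batch_i,
      # Stage 2 (data_dist stream): input_dist for batch i+1.
      "input_dist": self.batch_ip1,
    }
    # Slide the window by one batch, as progress() does before returning.
    self.batch_i, self.batch_ip1 = self.batch_ip1, batch_ip2
    return stages


pipeline = ToyPipeline()
batches = iter(range(5))
for _ in range(3):
  print(pipeline.progress(batches))
```

In this toy, a fourth call raises `StopIteration` from `next()` before batches 3 and 4 are run. `progress()` above behaves the same way, because it fetches batch i+2 before its forward pass.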

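The `grad_accum` behavior described in `progress()` can be checked with scalar stand-ins for gradients. In that behavior, the loss is divided by `_grad_accum`, and there is one gradient reset and one optimizer update per `_grad_accum` training calls. This is a hypothetical sketch, not repo code. `simulate_grad_accum` assumes gradients are zeroed on the first call of each group and the optimizer steps on the last call. It ignores AMP and gradient scaling.

```python
from typing import List, Optional, Sequence


def simulate_grad_accum(per_batch_grads: Sequence[float], grad_accum: Optional[int]) -> List[float]:
  """Return the gradient value seen by each optimizer step (hypothetical sketch)."""
  k = 1 if grad_accum is None else grad_accum
  applied: List[float] = []
  accumulated = 0.0
  for call, grad in enumerate(per_batch_grads):
    # zero_grad() on the first call of each group of k calls.
    if call % k == 0:
      accumulated = 0.0
    # Dividing the loss by k divides its gradient by k as well.
    accumulated += grad / k
    # Optimizer step on the last call of each group.
    if (call + 1) % k == 0:
      applied.append(accumulated)
  return applied


# Approximately [2.0, 5.0]: each step sees the mean gradient of its group of 3 batches.
print(simulate_grad_accum([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], grad_accum=3))
```

With `grad_accum=None`, every call resets and steps, which matches the non-accumulating path.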

================================================
FILE: images/init_venv.sh
================================================
#!/bin/bash

if [[ "$(uname)" == "Darwin" ]]; then
  echo "Only supported on Linux."
  exit 1
fi

# You may need to point this to a version of python 3.10
PYTHONBIN="/opt/ee/python/3.10/bin/python3.10"
echo Using "PYTHONBIN=$PYTHONBIN"

# Recreate the venv from scratch; it is not meant to last, so just rebuild it.
VENV_PATH="$HOME/tml_venv"
rm -rf "$VENV_PATH"
"$PYTHONBIN" -m venv "$VENV_PATH"

# shellcheck source=/dev/null
. "$VENV_PATH/bin/activate"

pip --require-virtualenv install -U pip
pip --require-virtualenv install --no-deps -r images/requirements.txt

ln -s "$(pwd)" "$VENV_PATH/lib/python3.10/site-packages/tml"

echo "Now run 'source ${VENV_PATH}/bin/activate' to get going."


================================================
FILE: images/requirements.txt
================================================
absl-py==1.4.0
aiofiles==22.1.0
aiohttp==3.8.3
aiosignal==1.3.1
appdirs==1.4.4
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.1.0
backcall==0.2.0
black==22.6.0
cachetools==5.3.0
cblack==22.6.0
certifi==2022.12.7
cfgv==3.3.1
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
Cython==0.29.32
decorator==5.1.1
distlib==0.3.6
distro==1.8.0
dm-tree==0.1.6
docker==6.0.1
docker-pycreds==0.4.0
docstring-parser==0.8.1
exceptiongroup==1.1.0
executing==1.2.0
fbgemm-gpu-cpu==0.3.2
filelock==3.8.2
fire==0.5.0
flatbuffers==1.12
frozenlist==1.3.3
fsspec==2022.11.0
gast==0.4.0
gcsfs==2022.11.0
gitdb==4.0.10
GitPython==3.1.31
google-api-core==2.8.2
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-cloud-core==2.3.2
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.4.1
googleapis-common-protos==1.56.4
grpcio==1.51.1
h5py==3.8.0
hypothesis==6.61.0
identify==2.5.17
idna==3.4
importlib-metadata==6.0.0
iniconfig==2.0.0
iopath==0.1.10
ipdb==0.13.11
ipython==8.10.0
jedi==0.18.2
Jinja2==3.1.2
keras==2.9.0
Keras-Preprocessing==1.1.2
libclang==15.0.6.1
libcst==0.4.9
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib-inline==0.1.6
moreorless==0.4.0
multidict==6.0.4
mypy==1.0.1
mypy-extensions==0.4.3
nest-asyncio==1.5.6
ninja==1.11.1
nodeenv==1.7.0
numpy==1.22.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.3
parso==0.8.3
pathspec==0.11.0
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.0.0
pluggy==1.0.0
portalocker==2.6.0
portpicker==1.5.2
pre-commit==3.0.4
prompt-toolkit==3.0.36
protobuf==3.20.2
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==10.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydantic==1.9.0
pyDeprecate==0.3.2
Pygments==2.14.0
pyparsing==3.0.9
pyre-extensions==0.0.27
pytest==7.2.1
pytest-mypy==0.10.3
python-dateutil==2.8.2
pytz==2022.6
PyYAML==6.0.0
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-build==0.16.3
sentry-sdk==1.16.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
sortedcontainers==2.4.0
stack-data==0.6.2
stdlibs==2022.10.9
tabulate==0.9.0
tensorboard==2.9.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.3
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.30.0
termcolor==2.2.0
toml==0.10.2
tomli==2.0.1
torch==1.13.1
torchmetrics==0.11.0
torchrec==0.3.2
torchsnapshot==0.1.0
torchx==0.3.0
tqdm==4.64.1
trailrunner==1.2.1
traitlets==5.9.0
typing-inspect==0.8.0
typing_extensions==4.4.0
urllib3==1.26.13
usort==1.0.5
virtualenv==20.19.0
wandb==0.13.11
wcwidth==0.2.6
websocket-client==1.4.2
Werkzeug==2.2.3
wrapt==1.14.1
yarl==1.8.2
zipp==3.12.1


================================================
FILE: machines/environment.py
================================================
import json
import os
from typing import List


KF_DDS_PORT: int = 5050
SLURM_DDS_PORT: int = 5051
FLIGHT_SERVER_PORT: int = 2222


def on_kf():
  return "SPEC_TYPE" in os.environ


def has_readers():
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return machines_config_env["dataset_worker"] is not None
  return os.environ.get("HAS_READERS", "False") == "True"


def get_task_type():
  if on_kf():
    return os.environ["SPEC_TYPE"]
  return os.environ["TASK_TYPE"]


def is_chief() -> bool:
  return get_task_type() == "chief"


def is_reader() -> bool:
  return get_task_type() == "datasetworker"


def is_dispatcher() -> bool:
  return get_task_type() == "datasetdispatcher"


def get_task_index():
  if on_kf():
    pod_name = os.environ["MY_POD_NAME"]
    return int(pod_name.split("-")[-1])
  else:
    raise NotImplementedError


def get_reader_port():
  if on_kf():
    return KF_DDS_PORT
  return SLURM_DDS_PORT


def get_dds():
  if not has_readers():
    return None
  dispatcher_address = get_dds_dispatcher_address()
  if dispatcher_address:
    return f"grpc://{dispatcher_address}"
  else:
    raise ValueError("Job does not have DDS.")


def get_dds_dispatcher_address():
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    dds_host = f"{job_name}-datasetdispatcher-0"
  else:
    dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
  return f"{dds_host}:{get_reader_port()}"


def get_dds_worker_address():
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    task_index = get_task_index()
    return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
  else:
    node = os.environ["SLURMD_NODENAME"]
    return f"{node}:{get_reader_port()}"


def get_num_readers():
  if not has_readers():
    return 0
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return int(machines_config_env["num_dataset_workers"] or 0)
  return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))


def get_flight_server_addresses():
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    return [
      f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
      for task_index in range(get_num_readers())
    ]
  else:
    raise NotImplementedError


def get_dds_journaling_dir():
  return os.environ.get("DATASET_JOURNALING_DIR", None)
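
The dispatcher address logic above branches on whether `SPEC_TYPE` is set (Kubeflow) or not (SLURM). The hypothetical `dispatcher_address` below makes that testable without touching `os.environ`: it re-implements `has_readers()` and `get_dds_dispatcher_address()` over a plain mapping. The job and node names in the usage are made up.

```python
import json
from typing import Mapping, Optional

KF_DDS_PORT = 5050
SLURM_DDS_PORT = 5051


def dispatcher_address(env: Mapping[str, str]) -> Optional[str]:
  """Hypothetical re-implementation of get_dds_dispatcher_address() over a mapping."""
  on_kf = "SPEC_TYPE" in env
  # has_readers(): Kubeflow reads MACHINES_CONFIG; otherwise a HAS_READERS flag.
  if on_kf:
    has_readers = json.loads(env["MACHINES_CONFIG"])["dataset_worker"] is not None
  else:
    has_readers = env.get("HAS_READERS", "False") == "True"
  if not has_readers:
    return None
  if on_kf:
    host, port = f"{env['JOB_NAME']}-datasetdispatcher-0", KF_DDS_PORT
  else:
    host, port = env["SLURM_JOB_NODELIST_HET_GROUP_0"], SLURM_DDS_PORT
  return f"{host}:{port}"


print(dispatcher_address({"HAS_READERS": "True", "SLURM_JOB_NODELIST_HET_GROUP_0": "node01"}))
```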


================================================
FILE: machines/get_env.py
================================================
import tml.machines.environment as env

from absl import app, flags


FLAGS = flags.FLAGS
flags.DEFINE_string("property", None, "Which property of the current environment to fetch.")


def main(argv):
  if FLAGS.property == "using_dds":
    print(f"{env.has_readers()}", flush=True)
  elif FLAGS.property == "has_readers":
    print(f"{env.has_readers()}", flush=True)
  elif FLAGS.property == "get_task_type":
    print(f"{env.get_task_type()}", flush=True)
  elif FLAGS.property == "is_datasetworker":
    print(f"{env.is_reader()}", flush=True)
  elif FLAGS.property == "is_dds_dispatcher":
    print(f"{env.is_dispatcher()}", flush=True)
  elif FLAGS.property == "get_task_index":
    print(f"{env.get_task_index()}", flush=True)
  elif FLAGS.property == "get_dataset_service":
    print(f"{env.get_dds()}", flush=True)
  elif FLAGS.property == "get_dds_dispatcher_address":
    print(f"{env.get_dds_dispatcher_address()}", flush=True)
  elif FLAGS.property == "get_dds_worker_address":
    print(f"{env.get_dds_worker_address()}", flush=True)
  elif FLAGS.property == "get_dds_port":
    print(f"{env.get_reader_port()}", flush=True)
  elif FLAGS.property == "get_dds_journaling_dir":
    print(f"{env.get_dds_journaling_dir()}", flush=True)
  elif FLAGS.property == "should_start_dds":
    print(env.is_reader() or env.is_dispatcher(), flush=True)


if __name__ == "__main__":
  app.run(main)


================================================
FILE: machines/is_venv.py
================================================
"""This is intended to be run as a module.
e.g. python -m tml.machines.is_venv

Exits with 0 if running in a venv, otherwise 1.
"""

import sys
import logging


def is_venv():
  # See https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv
  return sys.base_prefix != sys.prefix


def _main():
  if is_venv():
    logging.info("In venv %s", sys.prefix)
    sys.exit(0)
  else:
    logging.error("Not in venv")
    sys.exit(1)


if __name__ == "__main__":
  _main()


================================================
FILE: machines/list_ops.py
================================================
"""
Simple str.split() parsing of input string

usage example:
  python list_ops.py --input_list=$INPUT [--sep=","] [--op=<len|select>] [--elem=$INDEX]

Args:
  - input_list: input string
  - sep (default ","): separator string
  - elem (default 0): integer index
  - op (default "select"): either `len` or `select`
    - len: prints len(input_list.split(sep))
    - select: prints input_list.split(sep)[elem]

Typical usage would be in a bash script, e.g.:

  LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)

"""
from absl import app, flags


FLAGS = flags.FLAGS
flags.DEFINE_string("input_list", None, "string to parse as list")
flags.DEFINE_integer("elem", 0, "which element to take")
flags.DEFINE_string("sep", ",", "separator")
flags.DEFINE_string("op", "select", "operation to do")


def main(argv):
  split_list = FLAGS.input_list.split(FLAGS.sep)
  if FLAGS.op == "select":
    print(split_list[FLAGS.elem], flush=True)
  elif FLAGS.op == "len":
    print(len(split_list), flush=True)
  else:
    raise ValueError(f"operation {FLAGS.op} not recognized.")


if __name__ == "__main__":
  app.run(main)


================================================
FILE: metrics/__init__.py
================================================
from .aggregation import StableMean  # noqa
from .auroc import AUROCWithMWU  # noqa
from .rce import NRCE, RCE  # noqa


================================================
FILE: metrics/aggregation.py
================================================
"""
Contains aggregation metrics.
"""
from typing import Tuple, Union

import torch
import torchmetrics


def update_mean(
  current_mean: torch.Tensor,
  current_weight_sum: torch.Tensor,
  value: torch.Tensor,
  weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Update the mean according to Welford formula:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
  Args:
    current_mean: The value of the current accumulated mean.
    current_weight_sum: The current sum of weights.
    value: The new value that needs to be added to get a new mean.
    weight: The weights for the new value.

  Returns: The updated mean and updated weighted sum.

  """
  weight = torch.broadcast_to(weight, value.shape)

  # Avoiding (on purpose) in-place operation when using += in case
  # current_mean and current_weight_sum share the same storage
  current_weight_sum = current_weight_sum + torch.sum(weight)
  current_mean = current_mean + torch.sum((weight / current_weight_sum) * (value - current_mean))
  return current_mean, current_weight_sum


def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
  """
  Merge the state from multiple workers.
  Args:
    state: A tensor with the first dimension indicating workers.

  Returns: The accumulated mean from all workers.

  """
  mean, weight_sum = update_mean(
    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    value=state[:, 0],
    weight=state[:, 1],
  )
  return torch.stack([mean, weight_sum])


class StableMean(torchmetrics.Metric):
  """
  This implements a numerically stable mean metric computation using the Welford algorithm, following
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  For example, when using float32, the algorithm gives a valid output even if the "sum" is larger
  than the maximum float32 value, as long as the mean is within the float32 range.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
  """

  def __init__(self, **kwargs):
    """
    Args:
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
    self.add_state(
      "mean_and_weight_sum",
      default=torch.zeros(2),
      dist_reduce_fx=stable_mean_dist_reduce_fn,
    )

  def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
    """
    Update the current mean.
    Args:
      value: Value to update the mean with.
      weight: weight to use. Shape should be broadcastable to that of value.
    """
    mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]

    if not isinstance(weight, torch.Tensor):
      weight = torch.as_tensor(weight, dtype=value.dtype, device=value.device)

    self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] = update_mean(
      mean, weight_sum, value, torch.as_tensor(weight)
    )

  def compute(self) -> torch.Tensor:
    """
    Compute and return the accumulated mean.
    """
    return self.mean_and_weight_sum[0]
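
Stripped of tensors, `update_mean` and `stable_mean_dist_reduce_fn` reduce to a few lines of arithmetic. The pure-Python sketch below uses hypothetical names (`welford_update`, `merge_workers`) and takes weights as plain lists rather than broadcast tensors. It shows two ways of reaching the ordinary weighted mean:

- folding batches in one at a time;
- folding per-worker `(mean, weight_sum)` rows into an empty accumulator.

```python
from typing import List, Sequence, Tuple


def welford_update(
  mean: float, weight_sum: float, values: Sequence[float], weights: Sequence[float]
) -> Tuple[float, float]:
  """Weighted batched mean update, mirroring the arithmetic of update_mean."""
  # Grow the running weight sum first; the new total is the denominator below.
  new_weight_sum = weight_sum + sum(weights)
  # Shift the mean by each value's weighted deviation from the *old* mean.
  new_mean = mean + sum((w / new_weight_sum) * (v - mean) for v, w in zip(values, weights))
  return new_mean, new_weight_sum


def merge_workers(states: List[Tuple[float, float]]) -> Tuple[float, float]:
  """Fold per-worker (mean, weight_sum) rows, as stable_mean_dist_reduce_fn does."""
  means = [m for m, _ in states]
  weight_sums = [w for _, w in states]
  # Each worker's mean is one value, weighted by that worker's total weight.
  return welford_update(0.0, 0.0, means, weight_sums)


# One worker sees two batches; the result is approximately 5.2 == (1 + 2 + 3 + 2 * 10) / 5.
mean, total = welford_update(0.0, 0.0, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
mean, total = welford_update(mean, total, [10.0], [2.0])
print(mean, total)

# Two workers with (mean, weight_sum) = (2.0, 3.0) and (10.0, 2.0) merge to the same mean.
print(merge_workers([(2.0, 3.0), (10.0, 2.0)]))
```

The key design point is that every term divides by the updated weight sum. No intermediate quantity ever holds the raw sum of values, which is what keeps the float32 version from overflowing.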


================================================
FILE: metrics/auroc.py
================================================
"""
AUROC metrics.
"""
from typing import Union

from tml.ml_logging.torch_logging import logging

import torch
import torchmetrics
from torchmetrics.utilities.data import dim_zero_cat


def _compute_helper(
  predictions: torch.Tensor,
  target: torch.Tensor,
  weights: torch.Tensor,
  max_positive_negative_weighted_sum: torch.Tensor,
  min_positive_negative_weighted_sum: torch.Tensor,
  equal_predictions_as_incorrect: bool,
) -> torch.Tensor:
  """
  Compute AUROC.
  Args:
    predictions: The predictions probabilities.
    target: The target.
    weights: The sample weights to assign to each sample in the batch.
    max_positive_negative_weighted_sum: The larger of the total positive-label weight and the
     total negative-label weight.
    min_positive_negative_weighted_sum: The smaller of those two totals. Their product
     normalizes the weighted count of correctly ordered (positive, negative) pairs.
    equal_predictions_as_incorrect: For positive & negative labels having identical scores,
     the pair is counted as a correct prediction (i.e. weight = 1) when this is False. Otherwise,
     it is counted as an incorrect prediction (i.e. weight = 0).
  """
  dim = 0

  # Sort predictions based on key (score, true_label). The order is ascending for score.
  # For true_label, the order is descending if equal_predictions_as_incorrect is True (so tied
  # negatives land after positives and are not counted); otherwise it is ascending.
  target_order = torch.argsort(target, dim=dim, descending=equal_predictions_as_incorrect)
  score_order = torch.sort(torch.gather(predictions, dim, target_order), stable=True, dim=dim)[1]
  score_order = torch.gather(target_order, dim, score_order)
  sorted_target = torch.gather(target, dim, score_order)
  sorted_weights = torch.gather(weights, dim, score_order)

  negatives_from_left = torch.cumsum((1.0 - sorted_target) * sorted_weights, 0)

  numerator = torch.sum(
    sorted_weights * (sorted_target * negatives_from_left / max_positive_negative_weighted_sum)
  )

  return numerator / min_positive_negative_weighted_sum


class AUROCWithMWU(torchmetrics.Metric):
  """
  AUROC using Mann-Whitney U-test.
  See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.

  This AUROC implementation is well suited to low (but non-zero) CTR. In particular, it will return
  the correct AUROC even if the predicted probabilities are all close to 0.
  Currently only binary classification is supported.
  """

  def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
    """

    Args:
      label_threshold: Labels strictly above this threshold are considered positive labels,
                       otherwise, they are considered negative.
      raise_missing_class: If True, an error will be raised if the negative or positive class is missing.
        Otherwise, we will simply log a warning.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
    self.add_state("predictions", default=[], dist_reduce_fx="cat")
    self.add_state("target", default=[], dist_reduce_fx="cat")
    self.add_state("weights", default=[], dist_reduce_fx="cat")

    self.label_threshold = label_threshold
    self.raise_missing_class = raise_missing_class

  def update(
    self,
    predictions: torch.Tensor,
    target: torch.Tensor,
    weight: Union[float, torch.Tensor] = 1.0,
  ) -> None:
    """
    Update the current auroc.
    Args:
      predictions: Predicted values, 1D Tensor or 2D Tensor of shape batch_size x 1.
      target: Ground truth. Must have same shape as predictions.
      weight: The weight to use for the predicted values. Shape should be
      broadcastable to that of predictions.
    """
    self.predictions.append(predictions)
    self.target.append(target)
    if not isinstance(weight, torch.Tensor):
      weight = torch.as_tensor(weight, dtype=predictions.dtype, device=target.device)
    self.weights.append(torch.broadcast_to(weight, predictions.size()))

  def compute(self) -> torch.Tensor:
    """
    Compute and return the accumulated AUROC.
    """
    weights = dim_zero_cat(self.weights)
    predictions = dim_zero_cat(self.predictions)
    target = dim_zero_cat(self.target).type_as(predictions)

    negative_mask = target <= self.label_threshold
    positive_mask = torch.logical_not(negative_mask)

    if not negative_mask.any():
      msg = "Negative class missing. AUROC returned will be meaningless."
      if self.raise_missing_class:
        raise ValueError(msg)
      else:
        logging.warning(msg)
    if not positive_mask.any():
      msg = "Positive class missing. AUROC returned will be meaningless."
      if self.raise_missing_class:
        raise ValueError(msg)
      else:
        logging.warning(msg)

    weighted_actual_negative_sum = torch.sum(
      torch.where(negative_mask, weights, torch.zeros_like(weights))
    )

    weighted_actual_positive_sum = torch.sum(
      torch.where(positive_mask, weights, torch.zeros_like(weights))
    )

    max_positive_negative_weighted_sum = torch.max(
      weighted_actual_negative_sum, weighted_actual_positive_sum
    )

    min_positive_negative_weighted_sum = torch.min(
      weighted_actual_negative_sum, weighted_actual_positive_sum
    )

    # Compute auroc with the weight set to 1 when positive & negative have identical scores.
    auroc_le = _compute_helper(
      target=target,
      weights=weights,
      predictions=predictions,
      min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,
      max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,
      equal_predictions_as_incorrect=False,
    )

    # Compute auroc with the weight set to 0 when positive & negative have identical scores.
    auroc_lt = _compute_helper(
      target=target,
      weights=weights,
      predictions=predictions,
      min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,
      max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,
      equal_predictions_as_incorrect=True,
    )

    # Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
    return auroc_le - (auroc_le - auroc_lt) / 2.0
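
`AUROCWithMWU` computes the weighted Mann-Whitney statistic with sorting and a cumulative sum. A brute-force reference over all positive/negative pairs is useful for checking it on small inputs. The sketch below makes these assumptions:

- `pairwise_auroc` is hypothetical (not in the repo) and runs in O(P·N) time.
- As in the class, labels strictly above `label_threshold` are positive.
- `tie_credit=1.0` and `tie_credit=0.0` play the roles of the `auroc_le` and `auroc_lt` terms. `compute()` returns their midpoint, which corresponds to `tie_credit=0.5`.

```python
from typing import Sequence


def pairwise_auroc(
  predictions: Sequence[float],
  targets: Sequence[float],
  weights: Sequence[float],
  label_threshold: float = 0.5,
  tie_credit: float = 0.5,
) -> float:
  """Weighted Mann-Whitney AUROC by brute force over all (positive, negative) pairs."""
  rows = list(zip(predictions, targets, weights))
  positives = [(p, w) for p, t, w in rows if t > label_threshold]
  negatives = [(p, w) for p, t, w in rows if t <= label_threshold]
  pos_weight = sum(w for _, w in positives)
  neg_weight = sum(w for _, w in negatives)
  if pos_weight == 0 or neg_weight == 0:
    raise ValueError("Need at least one positive and one negative example.")
  numerator = 0.0
  for p_pos, w_pos in positives:
    for p_neg, w_neg in negatives:
      if p_pos > p_neg:
        numerator += w_pos * w_neg  # correctly ordered pair
      elif p_pos == p_neg:
        numerator += tie_credit * w_pos * w_neg  # tied scores
  # Normalize by the total weight of all (positive, negative) pairs.
  return numerator / (pos_weight * neg_weight)


preds = [0.2, 0.5, 0.5, 0.9]
labels = [0.0, 0.0, 1.0, 1.0]
ones = [1.0] * 4
print(pairwise_auroc(preds, labels, ones, tie_credit=1.0))  # 1.0
print(pairwise_auroc(preds, labels, ones, tie_credit=0.0))  # 0.75
print(pairwise_auroc(preds, labels, ones))  # 0.875
```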

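The RCE and NRCE definitions in `metrics/rce.py` (the next file) can be checked on toy data. The sketch below is a hypothetical, unweighted pure-Python version: `rce` and `nrce` are illustrative names, and the repo's classes also handle weights, label smoothing, and logits. It reuses the clipping of `_binary_cross_entropy_with_clipping`. It also applies the normalization `prediction * average(label) / average(prediction)` from the `NRCE` docstring.

```python
import math
from typing import Sequence


def clipped_bce(p: float, y: float, eps: float = 1e-7) -> float:
  # Same clipping and epsilon fuzz as _binary_cross_entropy_with_clipping.
  p = min(max(p, eps), 1.0 - eps)
  return -y * math.log(p + eps) - (1.0 - y) * math.log(1.0 - p + eps)


def rce(preds: Sequence[float], labels: Sequence[float], eps: float = 1e-7) -> float:
  mean_label = sum(labels) / len(labels)
  # The reference model always predicts the mean label, so its CE is H(mean_label).
  baseline_ce = clipped_bce(mean_label, mean_label, eps)
  model_ce = sum(clipped_bce(p, y, eps) for p, y in zip(preds, labels)) / len(labels)
  return 100.0 * (1.0 - model_ce / baseline_ce)


def nrce(preds: Sequence[float], labels: Sequence[float], eps: float = 1e-7) -> float:
  # Rescale predictions so their average equals the average label, then take RCE.
  scale = (sum(labels) / len(labels)) / (sum(preds) / len(preds))
  return rce([p * scale for p in preds], labels, eps)


labels = [1.0, 0.0, 0.0, 0.0]
print(rce([0.25] * 4, labels))  # ~0: the constant mean-label predictor is the baseline
print(rce([0.9, 0.1, 0.1, 0.1], labels))  # positive: better than the baseline
print(rce([0.5] * 4, labels), nrce([0.5] * 4, labels))  # negative RCE, ~0 NRCE
```

The last line illustrates the docstring's point that a big gap between NRCE and RCE signals miscalibration. Rescaling the constant 0.5 prediction to the label mean turns a negative RCE into a score of about zero.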

================================================
FILE: metrics/rce.py
================================================
"""
Contains RCE metrics.
"""
import copy
from functools import partial
from typing import Union

from tml.metrics import aggregation

import torch
import torchmetrics


def _smooth(
  value: torch.Tensor, label_smoothing: Union[float, torch.Tensor]
) -> Union[float, torch.Tensor]:
  """
  Smooth given values.
  Args:
    value: Value to smooth.
    label_smoothing: smoothing constant.
  Returns: Smoothed values.
  """
  return value * (1.0 - label_smoothing) + 0.5 * label_smoothing


def _binary_cross_entropy_with_clipping(
  predictions: torch.Tensor,
  target: torch.Tensor,
  epsilon: Union[float, torch.Tensor],
  reduction: str = "none",
) -> torch.Tensor:
  """
  Clip predictions and apply binary cross entropy.
  This is done to match the implementation in keras at
  https://github.com/keras-team/keras/blob/r2.9/keras/backend.py#L5294-L5300
  Args:
    predictions: Predicted probabilities.
    target: Ground truth.
    epsilon: Epsilon fuzz factor used to clip the predictions.
    reduction: The reduction method to use.

  Returns: Binary cross entropy on the clipped predictions.

  """
  predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  bce = -target * torch.log(predictions + epsilon)
  bce -= (1.0 - target) * torch.log(1.0 - predictions + epsilon)
  if reduction == "mean":
    return torch.mean(bce)
  return bce


class RCE(torchmetrics.Metric):
  """
  Compute the relative cross entropy (`RCE <http://go/rce>`_).

  RCE is a metric used for models predicting the probability of success (p), e.g. pCTR.
  RCE represents the binary `cross entropy <https://en.wikipedia.org/wiki/Cross_entropy>`_ of
  the model compared to a reference straw man model.

  Binary cross entropy is defined as:

  y = label; p = prediction;
  binary cross entropy(example) = - y * log(p) - (1-y) * log(1-p)

  Where y in {0, 1}

  Cross entropy of a model is defined as:

  CE(model) = average(binary cross entropy(example))

  over all the examples we aggregate on.

  The straw man model is quite simple: it is a constant predictor that always predicts the
  average label.

  RCE of a model is defined as:

  RCE(model) = 100 * (CE(reference model) - CE(model)) / CE(reference model)

  .. note:: Maximizing the likelihood is the same as minimizing the cross entropy or maximizing
            the RCE, since cross entropy is the average negative log-likelihood in the binary case.

  .. note:: Binary cross entropy of an example is non-negative and equal to the
            `KL divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Properties>`_
            between the label and the prediction, because the entropy of a deterministic
            label (y in {0, 1}) is zero.

  .. note:: 0% RCE means the model is as good as the straw man model.
            100% RCE means the model always predicts exactly the label, i.e. its cross entropy
            is always zero. In practice 100% is impossible to achieve due to clipping.
            Negative RCE means that the model is doing worse than the straw man.
            This usually indicates an uncalibrated model, i.e. the average prediction
            is "far" from the average label. Examining NRCE might help identify whether that is
            the case.

  .. note:: RCE is not a "ratio" in the statistical
            `level of measurement sense <https://en.wikipedia.org/wiki/Level_of_measurement>`.
            The higher the model's RCE is, the harder it is to improve it by an extra point.

            For example:
            Let CE(model) = 0.5 CE(reference model), then the RCE(model) = 50.
            Now take a "twice as good" model:
            Let CE(better model) = 0.5 CE(model) = 0.25 CE(reference model),
            then the RCE(better model) = 75 and not 100.

  .. note:: In order to keep the log function stable, typically p is limited to
            lie in [CLAMP_EPSILON, 1-CLAMP_EPSILON],
            where CLAMP_EPSILON is some small constant like: 1e-7.
            The old implementation used 1e-5 clipping by default; the current one uses
            tf.keras.backend.epsilon(), whose default is 1e-7.

  .. note:: Since the reference model prediction is constant (probability),
            CE(reference model) = H(average(label))

            Where H is the standard
            `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>` function.

  .. note:: Must have at least 1 positive and 1 negative sample accumulated,
            or RCE will come out as NaN.
  """

  def __init__(
    self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
  ):
    """
    Args:
      from_logits: Whether predictions are logits (True) or probabilities (False).
      label_smoothing: label smoothing constant.
      epsilon: Epsilon fuzz factor used on the predictions probabilities when from_logits is False.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
    self.from_logits = from_logits
    self.label_smoothing = label_smoothing
    self.epsilon = epsilon
    self.kwargs = kwargs

    self.mean_label = aggregation.StableMean(**kwargs)
    self.binary_cross_entropy = aggregation.StableMean(**kwargs)

    if self.from_logits:
      self.bce_loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
    else:
      self.bce_loss_fn = partial(_binary_cross_entropy_with_clipping, epsilon=self.epsilon)

    # Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.
    self.batch_metric = copy.deepcopy(self)

  def update(
    self, predictions: torch.Tensor, target: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0
  ) -> None:
    """
    Update the current rce.
    Args:
      predictions: Predicted values.
      target: Ground truth. Should have same shape as predictions.
      weight: The weight to use for the predicted values. Shape should be broadcastable to that of
       predictions.
    """
    target = _smooth(target, self.label_smoothing)
    self.mean_label.update(target, weight)
    self.binary_cross_entropy.update(
      self.bce_loss_fn(predictions, target, reduction="none"), weight
    )

  def compute(self) -> torch.Tensor:
    """
    Compute and return the accumulated rce.
    """
    baseline_mean = self.mean_label.compute()

    baseline_ce = _binary_cross_entropy_with_clipping(
      baseline_mean, baseline_mean, reduction="mean", epsilon=self.epsilon
    )

    pred_ce = self.binary_cross_entropy.compute()

    return (1.0 - (pred_ce / baseline_ce)) * 100

  def reset(self):
    """
    Reset the metric to its initial state.
    """
    super().reset()
    self.mean_label.reset()
    self.binary_cross_entropy.reset()

  def forward(self, *args, **kwargs):
    """
    Serves the dual purpose of computing the metric on the current batch of inputs and
        adding the batch statistics to the overall accumulating metric state.
    Input arguments are the same as those of the corresponding ``update`` method.
    The returned value is what ``compute`` would return for the current batch alone.
    """
    self.update(*args, **kwargs)
    self.batch_metric.update(*args, **kwargs)
    batch_result = self.batch_metric.compute()
    self.batch_metric.reset()
    return batch_result


class NRCE(RCE):
  """
  Calculate the RCE of the normalized model, i.e. a model whose average prediction is
  rescaled to match the average label seen so far.
  Namely, the normalized model prediction is:

  normalized model prediction(example) = (model prediction(example) * average(label)) /
  average(model prediction)

  Where the average is over all previously seen examples.

  .. note:: average(normalized model prediction) = average(label)

  .. note:: NRCE can be misleading since it is oblivious to mis-calibrations.
            The common interpretation of NRCE is to measure how good your model could potentially
            perform if it was well calibrated.

  .. note:: A big gap between NRCE and RCE might indicate a badly calibrated model.

  """

  def __init__(
    self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
  ):
    """

    Args:
      from_logits: Whether predictions are logits (True) or probabilities (False).
      label_smoothing: Label smoothing constant.
      epsilon: Epsilon fuzz factor used to clip the prediction probabilities (after applying a
               sigmoid when from_logits is True). It is only used when computing the cross
               entropy, not when normalizing.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(from_logits=False, label_smoothing=0, epsilon=epsilon, **kwargs)
    self.nrce_from_logits = from_logits
    self.nrce_label_smoothing = label_smoothing
    self.mean_prediction = aggregation.StableMean()

    # Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.
    self.batch_metric = copy.deepcopy(self)

  def update(
    self,
    predictions: torch.Tensor,
    target: torch.Tensor,
    weight: Union[float, torch.Tensor] = 1.0,
  ):
    """
    Update the current nrce.
    Args:
      predictions: Predicted values.
      target: Ground truth. Should have same shape as predictions.
      weight: The weight to use for the predicted values. Shape should be broadcastable to that of
       predictions.
    """
    predictions = torch.sigmoid(predictions) if self.nrce_from_logits else predictions

    target = _smooth(target, self.nrce_label_smoothing)
    self.mean_label.update(target, weight)

    self.mean_prediction.update(predictions, weight)

    normalizer = self.mean_label.compute() / self.mean_prediction.compute()

    predictions = predictions * normalizer

    self.binary_cross_entropy.update(
      self.bce_loss_fn(predictions, target, reduction="none"), weight
    )

  def reset(self):
    """
    Reset the metric to its initial state.
    """
    super().reset()
    self.mean_prediction.reset()
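
The accumulated `RCE` metric reduces to a short formula: RCE = (1 - CE(model) / CE(baseline)) × 100, where the baseline always predicts the weighted mean label, so its cross entropy is the entropy of that mean. Below is a minimal pure-Python sketch for probability inputs. It assumes `aggregation.StableMean` computes a weighted mean (sum of w·x over sum of w) and skips label smoothing, logits, and the numerically stable accumulation the real class relies on; the names `_clipped_bce` and `rce` are illustrative, not part of `tml.metrics`.

```python
import math
from typing import Sequence


def _clipped_bce(p: float, y: float, eps: float = 1e-7) -> float:
  """Binary cross entropy of one example, clipping p into [eps, 1 - eps]."""
  p = min(max(p, eps), 1.0 - eps)
  return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def rce(
  predictions: Sequence[float],
  labels: Sequence[float],
  weights: Sequence[float],
  eps: float = 1e-7,
) -> float:
  """Relative cross entropy in percent (sketch; assumes weighted means)."""
  total_weight = sum(weights)
  mean_label = sum(w * y for y, w in zip(labels, weights)) / total_weight
  # Weighted mean cross entropy of the model's predictions.
  pred_ce = (
    sum(w * _clipped_bce(p, y, eps) for p, y, w in zip(predictions, labels, weights))
    / total_weight
  )
  # Baseline: always predict the mean label; its CE is the entropy of that mean.
  baseline_ce = _clipped_bce(mean_label, mean_label, eps)
  return (1.0 - pred_ce / baseline_ce) * 100.0
```

A score of 0 means the model does no better than always predicting the average label; positive values are improvements, and negative values mean the model is worse than that baseline.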

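`NRCE` differs from `RCE` only in rescaling predictions before taking the cross entropy. The sketch below, again plain Python over probabilities with illustrative names, normalizes a single batch at once. `NRCE.update` above instead recomputes the normalizer from the running means on every call, and applies a sigmoid first when `from_logits=True`.

```python
import math
from typing import List, Sequence


def _weighted_mean(xs: Sequence[float], ws: Sequence[float]) -> float:
  return sum(w * x for x, w in zip(xs, ws)) / sum(ws)


def _clipped_bce(p: float, y: float, eps: float = 1e-7) -> float:
  p = min(max(p, eps), 1.0 - eps)
  return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def normalize(
  predictions: Sequence[float], labels: Sequence[float], weights: Sequence[float]
) -> List[float]:
  # Rescale so the weighted mean prediction equals the weighted mean label.
  scale = _weighted_mean(labels, weights) / _weighted_mean(predictions, weights)
  return [p * scale for p in predictions]


def nrce(
  predictions: Sequence[float],
  labels: Sequence[float],
  weights: Sequence[float],
  eps: float = 1e-7,
) -> float:
  normalized = normalize(predictions, labels, weights)
  pred_ce = _weighted_mean(
    [_clipped_bce(p, y, eps) for p, y in zip(normalized, labels)], weights
  )
  mean_label = _weighted_mean(labels, weights)
  baseline_ce = _clipped_bce(mean_label, mean_label, eps)
  return (1.0 - pred_ce / baseline_ce) * 100.0
```

For labels `[1, 0, 1, 0]` and predictions `[0.4, 0.2, 0.4, 0.2]` (mean prediction 0.3 versus mean label 0.5), NRCE is about 41.5 while plain RCE would be about 17.8: the kind of gap the docstring attributes to miscalibration.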

================================================
FILE: ml_logging/__init__.py
================================================


================================================
FILE: ml_logging/absl_logging.py
================================================
"""Sets up logging through absl for training usage.

- Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate.

Usage:
    >>> from tml.ml_logging.absl_logging import logging
    >>> logging.info("Properly logged as INFO level in GCP Stackdriver.")

"""
import logging as py_logging
import sys

from absl import logging as logging


def setup_absl_logging():
  """Make sure that absl logging pushes to stdout rather than stderr."""
  logging.get_absl_handler().python_handler.stream = sys.stdout
  formatter = py_logging.Formatter(
    fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
  )
  logging.get_absl_handler().setFormatter(formatter)
  logging.set_verbosity(logging.INFO)


setup_absl_logging()

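`setup_absl_logging` works by pointing absl's underlying Python handler at `sys.stdout` and attaching a formatter. The same idea using only the standard `logging` module looks roughly like this. It is an analogue for illustration, not how absl is configured; `make_stdout_logger` and its `stream` parameter are hypothetical.

```python
import logging
import sys
from typing import Optional, TextIO


def make_stdout_logger(name: str, stream: Optional[TextIO] = None) -> logging.Logger:
  """Route `name`'s records to stdout (or `stream`) at INFO level and above."""
  handler = logging.StreamHandler(stream if stream is not None else sys.stdout)
  # Same format string as setup_absl_logging above.
  handler.setFormatter(
    logging.Formatter(fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s")
  )
  logger = logging.getLogger(name)
  logger.handlers = [handler]  # replace any existing (e.g. stderr) handlers
  logger.setLevel(logging.INFO)
  logger.propagate = False  # avoid duplicate output via the root logger
  return logger
```

Writing to stdout matters because, per the module docstring above, it keeps severity levels in GCP Stackdriver accurate.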

================================================
FILE: ml_logging/test_torch_logging.py
================================================
import unittest

from tml.ml_logging.torch_logging import logging


class Testtlogging(unittest.TestCase):
  def test_warn_once(self):
    with self.assertLogs(level="INFO") as captured_logs:
      logging.info("first info")
      logging.warning("first warning")
      logging.warning("first warning")
      logging.info("second info")

    self.assertEqual(
      captured_logs.output,
      [
        "INFO:absl:first info",
        "WARNING:absl:first warning",
        "INFO:absl:second info",
      ],
    )


================================================
FILE: ml_logging/torch_logging.py
================================================
"""Overrides absl logger to be rank-aware for distributed pytorch usage.

    >>> from tml.ml_logging.torch_logging import logging
    >>> logging.info("This only prints on rank 0 if distributed, otherwise prints normally.")
    >>> logging.info("This prints on all ranks if distributed, otherwise prints normally.", rank=-1)

"""
import functools
from typing import Optional

from tml.ml_logging.absl_logging import logging as logging
from absl import logging as absl_logging

import torch.distributed as dist


def rank_specific(logger):
  """Ensures that we only override a given logger once."""
  if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
    return logger

  def _if_rank(logger_method, limit: Optional[int] = None):
    if limit:
      # If we are limiting redundant logs, wrap logging call with a cache
      # to not execute if already cached.
      def _wrap(_call):
        @functools.lru_cache(limit)
        def _logger_method(*args, **kwargs):
          _call(*args, **kwargs)

        return _logger_method

      logger_method = _wrap(logger_method)

    def _inner(msg, *args, rank: int = 0, **kwargs):
      if not dist.is_initialized():
        logger_method(msg, *args, **kwargs)
      elif dist.get_rank() == rank:
        logger_method(msg, *args, **kwargs)
      elif rank < 0:
        logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)

    # Register this stack frame with absl logging so that it doesn't trample logging lines.
    absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)

    return _inner

  logger.fatal = _if_rank(logger.fatal)
  logger.error = _if_rank(logger.error)
  logger.warning = _if_rank(logger.warning, limit=1)
  logger.info = _if_rank(logger.info)
  logger.debug = _if_rank(logger.debug)
  logger.exception = _if_rank(logger.exception)

  logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
  return logger


rank_specific(logging)

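The rank gating in `rank_specific` can be exercised without a process group. In this sketch, a hypothetical `get_rank` callable stands in for `torch.distributed.is_initialized()` / `get_rank()` (returning `None` when not distributed), and `emit` stands in for the wrapped logger method. As in the original, `limit` wraps the call in `functools.lru_cache(limit)`, so with `limit=1` a call identical to the one currently cached is skipped; this is why the test above sees "first warning" only once.

```python
import functools
from typing import Callable, Optional


def rank_gated(
  emit: Callable[[str], None],
  get_rank: Callable[[], Optional[int]],
  limit: Optional[int] = None,
) -> Callable[..., None]:
  """Wrap `emit` so it fires only on the requested rank (hypothetical helper)."""
  if limit:
    # Calls are cached by argument; a cache hit means `emit` is not run again.
    emit = functools.lru_cache(limit)(emit)

  def _inner(msg: str, rank: int = 0) -> None:
    current = get_rank()
    if current is None or current == rank:
      emit(msg)  # not distributed, or this process is the requested rank
    elif rank < 0:
      emit(f"Rank{current}: {msg}")  # rank=-1: every rank logs, with a prefix

  return _inner
```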

================================================
FILE: model.py
================================================
"""Wraps servable model in loss and RecapBatch passing to be trainable."""
# flake8: noqa
from typing import Callable

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel


class ModelAndLoss(torch.nn.Module):
  # Reconsider our approach at a later date: https://ppwwyyxx.com/blog/2022/Loss-Function-Separation/

  def __init__(
    self,
    model,
    loss_fn: Callable,
  ) -> None:
    """
    Args:
      model: torch module to wrap.
      loss_fn: Function for calculating loss; called as loss_fn(logits, labels, weights).
    """
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn

  def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """Runs model forward and calculates loss according to given loss_fn.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline. However,
    the underlying model signature needs to be exportable to ONNX, requiring
    generic Python types. See https://pytorch.org/docs/stable/onnx.html#types.

    """
    outputs = self.model(batch)
    losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())

    outputs.update(
      {
        "loss": losses,
        "labels": batch.labels,
        "weights": batch.weights,
      }
    )

    # Allow multiple losses.
    return losses, outputs


def maybe_shard_model(
  model,
  device: torch.device,
):
  """Set up and apply DistributedModelParallel to a model if running in a distributed environment.

  If in a distributed environment, wraps the model in DistributedModelParallel, which constructs
  the default topology, sharders, and sharding plan.

  If not in a distributed environment, returns the model unchanged.
  """
  if dist.is_initialized():
    logging.info("***** Wrapping in DistributedModelParallel *****")
    logging.info(f"Model before wrapping: {model}")
    model = DistributedModelParallel(
      module=model,
      device=device,
    )
    logging.info(f"Model after wrapping: {model}")

  return model


def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
  """Handy function to log the content of EBC embedding layer.
     Only works for single GPU machines.

  Args:
      weight_name: name of tensor, as defined in model
      table_name: name of the EBC table the weight is taken from
      weight_tensor: embedding weight tensor
  """
  logging.info(f"{weight_name}, {table_name}", rank=-1)
  logging.info(f"{weight_tensor.metadata()}", rank=-1)
  output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))
  weight_tensor.gather(out=output_tensor)
  logging.info(f"{output_tensor}", rank=-1)


================================================
FILE: optimizers/__init__.py
================================================
from tml.optimizers.optimizer import compute_lr


================================================
FILE: optimizers/config.py
================================================
"""Optimization configurations for models."""

import typing

import tml.core.config as base_config

import pydantic


class PiecewiseConstant(base_config.BaseConfig):
  learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
  learning_rate_values: typing.List[float] = pydantic.Field(None)


class LinearRampToConstant(base_config.BaseConfig):
  learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
    description="Number of steps to ramp this up from zero."
  )


class LinearRampToCosine(base_config.BaseConfig):
  learning_rate: float
  final_learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
    description="Number of steps to ramp this up from zero."
  )
  final_num_steps: pydantic.PositiveInt = pydantic.Field(
    description="Final number of steps where decay stops."
  )


class LearningRate(base_config.BaseConfig):
  constant: float = pydantic.Field(None, one_of="lr")
  linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
  linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
  piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of="lr")


class OptimizerAlgorithmConfig(base_config.BaseConfig):
  """Base class for optimizer configurations."""

  lr: float
  ...


class AdamConfig(OptimizerAlgorithmConfig):
  # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
  lr: float
  betas: typing.Tuple[float, float] = [0.9, 0.999]
  eps: float = 1e-7  # Numerical stability in denominator.


class SgdConfig(OptimizerAlgorithmConfig):
  lr: float
  momentum: float = 0.0


class AdagradConfig(OptimizerAlgorithmConfig):
  lr: float
  eps: float = 0


class OptimizerConfig(base_config.BaseConfig):
  learning_rate: LearningRate = pydantic.Field(
    None,
    description="Learning rate schedule (constant, piecewise constant, or linear ramp).",
  )
  adam: AdamConfig = pydantic.Field(None, one_of="optimizer")
  sgd: SgdConfig = pydantic.Field(None, one_of="optimizer")
  adagrad: AdagradConfig = pydantic.Field(None, one_of="optimizer")


def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
  if optimizer_config.adam is not None:
    return optimizer_config.adam
  elif optimizer_config.sgd is not None:
    return optimizer_config.sgd
  elif optimizer_config.adagrad is not None:
    return optimizer_config.adagrad
  else:
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")


================================================
FILE: optimizers/optimizer.py
================================================
from typing import Dict, Tuple
import math
import bisect

from tml.optimizers.config import (
  LearningRate,
  OptimizerConfig,
)

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config, step):
  """Compute a learning rate."""
  if lr_config.constant is not None:
    return lr_config.constant
  elif lr_config.piecewise_constant is not None:
    return lr_config.piecewise_constant.learning_rate_values[
      bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
    ]
  elif lr_config.linear_ramp_to_constant is not None:
    slope = (
      lr_config.linear_ramp_to_constant.learning_rate
      / lr_config.linear_ramp_to_constant.num_ramp_steps
    )
    return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
  elif lr_config.linear_ramp_to_cosine is not None:
    cfg = lr_config.linear_ramp_to_cosine
    if step < cfg.num_ramp_steps:
      slope = cfg.learning_rate / cfg.num_ramp_steps
      return slope * step
    elif step <= cfg.final_num_steps:
      return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
        1.0
        + math.cos(
          math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
        )
      )
    else:
      return cfg.final_learning_rate
  else:
    raise ValueError(f"No option selected in lr_config, passed {lr_config}")


class LRShim(_LRScheduler):
  """Shim to get learning rates into a LRScheduler.

  This adheres to the torch.optim scheduler API and can be plugged in anywhere that,
  e.g., exponential decay can be used.
  """

  def __init__(
    self,
    optimizer,
    lr_dict: Dict[str, LearningRate],
    last_epoch=-1,
    verbose=False,
  ):
    self.optimizer = optimizer
    self.lr_dict = lr_dict
    self.group_names = list(self.lr_dict.keys())

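    if hasattr(optimizer, "_optims"):
      # Combined optimizer (e.g. torchrec's): count the param groups of every wrapped optimizer.
      num_param_groups = sum(len(_optim.param_groups) for _, _optim in optimizer._optims)
    else:
      num_param_groups = len(optimizer.param_groups)
    if num_param_groups != len(lr_dict):
      raise ValueError(
        f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
      )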

    super().__init__(optimizer, last_epoch, verbose)

  def get_lr(self):
    if not self._get_lr_called_within_step:
      logging.warning(
        "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
      )
    return self._get_closed_form_lr()

  def _get_closed_form_lr(self):
    return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]


def get_optimizer_class(optimizer_config: OptimizerConfig):
  if optimizer_config.adam is not None:
    return torch.optim.Adam
  elif optimizer_config.sgd is not None:
    return torch.optim.SGD
  elif optimizer_config.adagrad is not None:
    return torch.optim.Adagrad
  else:
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")


def build_optimizer(
  model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
  """Builds an optimizer and LR scheduler from an OptimizerConfig.
  Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
  """
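  from tml.optimizers.config import get_optimizer_algorithm_config

  optimizer_class = get_optimizer_class(optimizer_config)
  # Pass the hyperparameters of the selected optimizer (adam, sgd, or adagrad).
  algorithm_config = get_optimizer_algorithm_config(optimizer_config)
  optimizer = optimizer_class(model.parameters(), **algorithm_config.dict())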
  # We're passing everything in as one group here
  scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
  return optimizer, scheduler

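Two of the schedules in `compute_lr` are easy to misread. `piecewise_constant` uses `bisect.bisect_right`, so `learning_rate_values` needs one more entry than `learning_rate_boundaries`, and a step equal to a boundary already uses the next value. `linear_ramp_to_cosine` ramps linearly up from zero, then follows a half cosine from `learning_rate` down to `final_learning_rate`. Here is a standalone sketch with plain arguments instead of the pydantic configs; the function names are illustrative.

```python
import bisect
import math
from typing import Sequence


def piecewise_constant(boundaries: Sequence[int], values: Sequence[float], step: int) -> float:
  # bisect_right counts the boundaries <= step, so values[i] holds for
  # boundaries[i - 1] <= step < boundaries[i]; needs len(values) == len(boundaries) + 1.
  return values[bisect.bisect_right(boundaries, step)]


def linear_ramp_to_cosine(
  learning_rate: float,
  final_learning_rate: float,
  num_ramp_steps: int,
  final_num_steps: int,
  step: int,
) -> float:
  if step < num_ramp_steps:
    # Linear warm-up from 0 to learning_rate.
    return learning_rate * step / num_ramp_steps
  if step <= final_num_steps:
    # Half-cosine decay from learning_rate down to final_learning_rate.
    progress = (step - num_ramp_steps) / (final_num_steps - num_ramp_steps)
    return final_learning_rate + (learning_rate - final_learning_rate) * 0.5 * (
      1.0 + math.cos(math.pi * progress)
    )
  # Hold the final value once the decay stops.
  return final_learning_rate
```

As in the original, `final_num_steps` must be greater than `num_ramp_steps`; otherwise the cosine phase divides by zero.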

================================================
FILE: projects/__init__.py
================================================


================================================
FILE: projects/home/recap/FEATURES.md
================================================
# Overview
Below is a description of the major feature groups which are input to the Twitter Heavy Ranking model.

Note that not every request will have every feature available, due to user settings or other constraints, and there may be some differences in how "For You" is ranked based on different variables.

## Aggregate Features
Twitter's aggregate features make up the bulk of Twitter's feature count. They are generated by maintaining rolling aggregations of feature values within a specific scope over a specific time window. We compute both long-term aggregates (50-day counts) and short-term "real-time" aggregates (windows of 3 days or less, typically 30-minute counts).

<details>
<summary><b>Show Details</b></summary>
Aggregate features are groups of multiple features generated as Cartesian crosses from a template, and have the format:
<table>
<tr>
<td><b>Feature Group Name</b></td>
<td><b>Engagement Scope</b></td>
<td><b>Feature To Aggregate</b></td>
<td><b>Aggregation Spec</b></td>
</tr>
</table>

<ul>
<li> The <b>Feature Group Name</b> is the name of the aggregate feature and also encodes the aggregation scope, that is, which entities are aggregated over.
<ul>
<li> For example, <code>"user_aggregate"</code> aggregates over unique user_ids, and <code>"user_author_aggregate"</code> aggregates over all user-author pairs. It also determines what fields the feature is joined to when being used. In the case of <code>"user_author_aggregate"</code>, the feature is joined to data corresponding to the specific user and the specific author. 
<li> The raw feature group names are often verbose and are simplified in the tables below.
</ul>
<li> <b>Engagement Scope</b> is the subset of tweets within the aggregation scope that will be aggregated over. Typically this is the name of an output engagement, like <code>recap.engagement.is_favorited</code>. In that case, we only aggregate over Tweets which are also Liked.
<li> The <b>Feature To Aggregate</b> is the feature we are accumulating over. If this value is <code>any_feature</code>, we aggregate the Tweet count. For example, <code>user_aggregate_v2.pair.recap.engagement.is_favorited.any_feature.50.days.count</code> is the number of Liked records for every user over the last 50 days.
<li> The <b>Aggregation Spec</b> is the aggregate to compute: which function, over which time window.
</ul>

For every Feature Group, we generate one feature for every possible combination of Engagement Scope, Feature To Aggregate, and Aggregation Spec. In particular, every row in the tables below generates one feature for every possible cross between its columns.
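
A minimal sketch of that cross, assuming feature names are the dot-joined `group.scope.feature.spec` strings seen in the examples in this section. The row labels in the tables below are simplified, so real group names may differ, and `aggregate_feature_names` is a hypothetical helper, not part of the codebase.

```python
import itertools
from typing import List, Sequence


def aggregate_feature_names(
  group: str, scopes: Sequence[str], features: Sequence[str], specs: Sequence[str]
) -> List[str]:
  # One feature per (engagement scope, feature to aggregate, aggregation spec) triple.
  return [
    f"{group}.{scope}.{feature}.{spec}"
    for scope, feature, spec in itertools.product(scopes, features, specs)
  ]
```

Under this reading, the `user_v2` row of `user_aggregate` (4 engagement scopes, 82 features to aggregate, 2 aggregation specs) corresponds to 4 × 82 × 2 = 656 features.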

<b>Example</b>:
For example, one such feature may be <code>user_aggregate_v2.pair.recap.engagement.is_favorited.engagement_features.in_network.replies.count.50.days.count</code>, which can be parsed into
<table>
<tr>
<td><b>Feature Group Name</b></td>
<td><b>Engagement Scope</b></td>
<td><b>Feature To Aggregate</b></td>
<td><b>Aggregation Spec</b></td>
</tr>
<tr>
<td><code>user_aggregate_v2.pair</code></td>
<td><code>recap.engagement.is_favorited</code></td>
<td><code>engagement_features.in_network.replies.count</code></td>
<td><code>50.days.count</code></td>
</tr>
</table>

This means that this feature aggregates
<ol>
<li> (Over every user),
<li> (Over only tweets favorited by the user),
<li> In network replies sent out by this user,
<li> (Counted over the last 50 days)
</ol>
The resulting value is then made available as a feature for that particular user.

</details>
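
Going the other way is ambiguous in general, because every part of the name can contain dots. One way to split a name is sketched below. It assumes the group name and the candidate engagement scopes are known, and that the spec always has the `<n>.<minutes|days>.<count|sum|sumsq>` shape used in the tables below; `parse_aggregate_feature` is hypothetical.

```python
import re
from typing import Dict, Sequence

# Trailing aggregation spec such as "50.days.count" or "30.minutes.count" (assumed shape).
_SPEC = re.compile(r"^(?P<rest>.+)\.(?P<spec>\d+\.(?:minutes|days)\.(?:count|sumsq|sum))$")


def parse_aggregate_feature(name: str, group: str, scopes: Sequence[str]) -> Dict[str, str]:
  prefix = group + "."
  if not name.startswith(prefix):
    raise ValueError(f"{name!r} does not start with group {group!r}")
  match = _SPEC.match(name[len(prefix):])
  if match is None:
    raise ValueError(f"no aggregation spec found in {name!r}")
  rest = match.group("rest")
  # Try longer scopes first so a scope that prefixes another one cannot shadow it.
  for scope in sorted(scopes, key=len, reverse=True):
    if rest.startswith(scope + "."):
      return {
        "group": group,
        "scope": scope,
        "feature": rest[len(scope) + 1:],
        "spec": match.group("spec"),
      }
  raise ValueError(f"no known engagement scope at the start of {rest!r}")
```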

Our aggregate features are listed below:
<details>
<summary><b><code>author_aggregate</code></b></summary>
These features aggregate over the author (or original author) of a tweet. Some of the features are short-duration (30 minutes) and some longer (50 days). The features track how many of an author's tweets were engaged with.
<br>
<table>
<tr>
<td>
<code>
author (real_time)
</code>
</td>
<td>
<code>
timelines.enagagement.is_retweeted_without_quote <br>
timelines.engagement.is_clicked <br>
timelines.engagement.is_dont_like <br>
timelines.engagement.is_dwelled <br>
timelines.engagement.is_favorited <br>
timelines.engagement.is_followed <br>
timelines.engagement.is_open_linked <br>
timelines.engagement.is_photo_expanded <br>
timelines.engagement.is_profile_clicked <br>
timelines.engagement.is_quoted <br>
timelines.engagement.is_replied <br>
timelines.engagement.is_retweeted <br>
timelines.engagement.is_tweet_share_dm_clicked <br>
timelines.engagement.is_tweet_share_dm_sent <br>
timelines.engagement.is_video_playback_50 <br>
timelines.engagement.is_video_quality_viewed <br>
timelines.engagement.is_video_viewed <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
30.minutes.count
</code>
</td>
</tr>

<tr>
<td>
<code>
original_author (real_time)
</code>
</td>
<td>
<code>
timelines.enagagement.is_retweeted_without_quote <br>
timelines.engagement.is_clicked <br>
timelines.engagement.is_dont_like <br>
timelines.engagement.is_dwelled <br>
timelines.engagement.is_favorited <br>
timelines.engagement.is_followed <br>
timelines.engagement.is_open_linked <br>
timelines.engagement.is_photo_expanded <br>
timelines.engagement.is_profile_clicked <br>
timelines.engagement.is_quoted <br>
timelines.engagement.is_replied <br>
timelines.engagement.is_retweeted <br>
timelines.engagement.is_tweet_share_dm_clicked <br>
timelines.engagement.is_tweet_share_dm_sent <br>
timelines.engagement.is_video_playback_50 <br>
timelines.engagement.is_video_quality_viewed <br>
timelines.engagement.is_video_viewed <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
30.minutes.count
</code>
</td>
</tr>


<tr>
<td>
<code>
original_author (real_time)
</code>
</td>
<td>
<code>
timelines.engagement.is_share_menu_clicked <br>
timelines.engagement.is_shared <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
30.minutes.count <br>
1.days.count <br>
</code>
</td>
</tr>

<tr>
<td>
<code>
original_author
</code>
</td>
<td>
<code>
recap.engagement.is_replied_reply_favorited_by_author <br>
recap.engagement.is_replied_reply_impressed_by_author <br>
recap.engagement.is_replied_reply_replied_by_author <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>

</table>
</details>


<details>
<summary><b><code>author-topic_aggregate</code></b></summary>
These features aggregate over a specific tweet author and a specific topic. We only accumulate long (50 day) counts. 
<br>
<table>
<tr>
<td>
<code>
author-topic
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_clicked <br>
recap.engagement.is_favorited <br>
recap.engagement.is_open_linked <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
recap.engagement.is_replied <br>
recap.engagement.is_retweeted <br>
recap.engagement.is_video_playback_50 <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>

</table>
</details>

<details>
<summary><b><code>list_aggregate</code></b></summary>
These features aggregate short term and long term engagement between a user and a list.
<br>
<table>
<tr>
<td>
<code>
user_list
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_clicked <br>
recap.engagement.is_favorited <br>
recap.engagement.is_open_linked <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
recap.engagement.is_replied <br>
recap.engagement.is_retweeted <br>
recap.engagement.is_video_playback_50 <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>

<tr>
<td>
<code>
list (real_time)
</code>
</td>
<td>
<code>
timelines.engagement.is_block_clicked <br>
timelines.engagement.is_dont_like <br>
timelines.engagement.is_dwelled <br>
timelines.engagement.is_favorited <br>
timelines.engagement.is_mute_clicked <br>
timelines.engagement.is_replied <br>
timelines.engagement.is_report_tweet_clicked <br>
timelines.engagement.is_retweeted <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
30.minutes.count
</code>
</td>
</tr>

</table>
</details>


<details>
<summary><b><code>user_aggregate</code></b></summary>
These features aggregate short term and long term engagement from a specific user. 

<br>
<table>
<tr>
<td>
<code>
user_v2
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_favorited <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
</code>
</td>
<td>
<code>
any_feature <br>
engagement_features.in_network.favorites.count <br>
engagement_features.in_network.replies.count <br>
engagement_features.in_network.retweets.count <br>
realgraph.num_favorites.days_since_last <br>
realgraph.num_favorites.elapsed_days <br>
realgraph.num_favorites.ewma <br>
realgraph.num_favorites.non_zero_days <br>
realgraph.num_inspected_tweets.days_since_last <br>
realgraph.num_inspected_tweets.elapsed_days <br>
realgraph.num_inspected_tweets.ewma <br>
realgraph.num_inspected_tweets.non_zero_days <br>
realgraph.num_mentions.days_since_last <br>
realgraph.num_mentions.elapsed_days <br>
realgraph.num_mentions.ewma <br>
realgraph.num_mentions.non_zero_days <br>
realgraph.num_profile_views.days_since_last <br>
realgraph.num_profile_views.elapsed_days <br>
realgraph.num_profile_views.ewma <br>
realgraph.num_profile_views.non_zero_days <br>
realgraph.num_retweets.days_since_last <br>
realgraph.num_retweets.elapsed_days <br>
realgraph.num_retweets.ewma <br>
realgraph.num_retweets.non_zero_days <br>
realgraph.num_tweet_clicks.days_since_last <br>
realgraph.num_tweet_clicks.elapsed_days <br>
realgraph.num_tweet_clicks.ewma <br>
realgraph.num_tweet_clicks.non_zero_days <br>
realgraph.total_dwell_time.days_since_last <br>
realgraph.total_dwell_time.elapsed_days <br>
realgraph.total_dwell_time.ewma <br>
realgraph.total_dwell_time.non_zero_days <br>
recap.earlybird.fav_count_v2 <br>
recap.earlybird.reply_count_v2 <br>
recap.earlybird.retweet_count_v2 <br>
recap.searchfeature.blender_score <br>
recap.searchfeature.fav_count <br>
recap.searchfeature.reply_count <br>
recap.searchfeature.retweet_count <br>
recap.searchfeature.text_score <br>
recap.tweetfeature.bidirectional_fav_count <br>
recap.tweetfeature.bidirectional_reply_count <br>
recap.tweetfeature.bidirectional_retweet_count <br>
recap.tweetfeature.contains_media <br>
recap.tweetfeature.conversational_count <br>
recap.tweetfeature.embeds_impression_count <br>
recap.tweetfeature.embeds_url_count <br>
recap.tweetfeature.from_mutual_follow <br>
recap.tweetfeature.has_card <br>
recap.tweetfeature.has_image <br>
recap.tweetfeature.has_link <br>
recap.tweetfeature.has_multiple_media <br>
recap.tweetfeature.has_news <br>
recap.tweetfeature.has_periscope <br>
recap.tweetfeature.has_pro_video <br>
recap.tweetfeature.has_trend <br>
recap.tweetfeature.has_video <br>
recap.tweetfeature.has_vine <br>
recap.tweetfeature.has_visible_link <br>
recap.tweetfeature.is_business_score <br>
recap.tweetfeature.is_extended_reply <br>
recap.tweetfeature.is_reply <br>
recap.tweetfeature.is_retweet <br>
recap.tweetfeature.is_sensitive <br>
recap.tweetfeature.link_count <br>
recap.tweetfeature.link_language <br>
recap.tweetfeature.match_searcher_langs <br>
recap.tweetfeature.match_searcher_main_lang <br>
recap.tweetfeature.match_ui_lang <br>
recap.tweetfeature.mention_searcher <br>
recap.tweetfeature.num_hashtags <br>
recap.tweetfeature.num_mentions <br>
recap.tweetfeature.reply_other <br>
recap.tweetfeature.reply_searcher <br>
recap.tweetfeature.retweet_other <br>
recap.tweetfeature.retweet_searcher <br>
recap.tweetfeature.tweet_count_from_user_in_snapshot <br>
recap.tweetfeature.unidirectiona_fav_count <br>
recap.tweetfeature.unidirectional_reply_count <br>
recap.tweetfeature.unidirectional_retweet_count <br>
recap.tweetfeature.user_rep <br>
recap.tweetfeature.video_view_count <br>
</code>
</td>
<td>
<code>
50.days.count<br>
50.days.sum<br>
</code>
</td>
</tr>
<tr>
<td>
<code>
user_v5
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_clicked<br>
recap.engagement.is_favorited<br>
recap.engagement.is_open_linked<br>
recap.engagement.is_photo_expanded<br>
recap.engagement.is_profile_clicked<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
recap.engagement.is_video_playback_50<br>
</code>
</td>
<td>
<code>
any_feature <br>
time_features.earlybird.last_favorite_since_creation_hrs<br>
time_features.earlybird.last_quote_since_creation_hrs<br>
time_features.earlybird.last_reply_since_creation_hrs<br>
time_features.earlybird.last_retweet_since_creation_hrs<br>
time_features.earlybird.time_since_last_favorite<br>
time_features.earlybird.time_since_last_quote<br>
time_features.earlybird.time_since_last_reply<br>
time_features.earlybird.time_since_last_retweet<br>
timelines.earlybird.decayed_favorite_count<br>
timelines.earlybird.decayed_quote_count<br>
timelines.earlybird.decayed_reply_count<br>
timelines.earlybird.decayed_retweet_count<br>
timelines.earlybird.embeds_impression_count_v2<br>
timelines.earlybird.embeds_url_count_v2<br>
timelines.earlybird.fake_favorite_count<br>
timelines.earlybird.fake_quote_count<br>
timelines.earlybird.fake_reply_count<br>
timelines.earlybird.fake_retweet_count<br>
timelines.earlybird.quote_count<br>
timelines.earlybird.visible_token_ratio<br>
timelines.earlybird.weighted_fav_count<br>
timelines.earlybird.weighted_quote_count<br>
timelines.earlybird.weighted_reply_count<br>
timelines.earlybird.weighted_retweet_count<br>
</code>
</td>
<td>
<code>
50.days.count<br>
50.days.sum<br>
50.days.sumsq<br>
</code>
</td>
</tr>

<tr>
<td>
<code>
user_v6
</code>
</td>
<td>
<code>
recap.engagement.is_replied_reply_favorited_by_author<br>
recap.engagement.is_replied_reply_impressed_by_author<br>
recap.engagement.is_replied_reply_replied_by_author<br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>

<tr>
<td>
<code>
user (twitter_wide)
</code>
</td>
<td>
<code>
any_label<br>
recap.engagement.is_favorited<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
</code>
</td>
<td>
<code>
any_feature <br>
recap.tweetfeature.contains_media<br>
recap.tweetfeature.has_card<br>
recap.tweetfeature.has_hashtag<br>
recap.tweetfeature.has_link<br>
recap.tweetfeature.has_mention<br>
recap.tweetfeature.is_reply<br>
timelines.earlybird.has_quote<br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>


<tr>
<td>
<code>
user (real_time)
</code>
</td>
<td>
<code>
timelines.enagagement.is_retweeted_without_quote<br>
timelines.engagement.is_clicked<br>
timelines.engagement.is_dont_like<br>
timelines.engagement.is_dwelled<br>
timelines.engagement.is_favorited<br>
timelines.engagement.is_followed<br>
timelines.engagement.is_open_linked<br>
timelines.engagement.is_photo_expanded<br>
timelines.engagement.is_profile_clicked<br>
timelines.engagement.is_quoted<br>
timelines.engagement.is_replied<br>
timelines.engagement.is_retweeted<br>
timelines.engagement.is_tweet_share_dm_clicked<br>
timelines.engagement.is_tweet_share_dm_sent<br>
timelines.engagement.is_video_playback_50<br>
timelines.engagement.is_video_quality_viewed<br>
timelines.engagement.is_video_viewed<br>
</code>
</td>
<td>
<code>
any_feature <br>
client_log_event.tweet.has_consumer_video<br>
client_log_event.tweet.photo_count<br>
</code>
</td>
<td>
<code>
30.minutes.count
</code>
</td>
</tr>

<tr>
<td>
<code>
user (48h_real_time_v5)
</code>
</td>
<td>
<code>
timelines.enagagement.is_retweeted_without_quote<br>
timelines.engagement.is_clicked<br>
timelines.engagement.is_dont_like<br>
timelines.engagement.is_dwelled<br>
timelines.engagement.is_favorited<br>
timelines.engagement.is_followed<br>
timelines.engagement.is_open_linked<br>
timelines.engagement.is_photo_expanded<br>
timelines.engagement.is_profile_clicked<br>
timelines.engagement.is_quoted<br>
timelines.engagement.is_replied<br>
timelines.engagement.is_retweeted<br>
timelines.engagement.is_tweet_share_dm_clicked<br>
timelines.engagement.is_tweet_share_dm_sent<br>
timelines.engagement.is_video_playback_50<br>
timelines.engagement.is_video_quality_viewed<br>
timelines.engagement.is_video_viewed<br>
</code>
</td>
<td>
<code>
any_feature <br>
client_log_event.tweet.has_consumer_video<br>
client_log_event.tweet.photo_count<br>
</code>
</td>
<td>
<code>
2.days.count
</code>
</td>
</tr>

<tr>
<td>
<code>
user (72h_real_time_v6)
</code>
</td>
<td>
<code>
timelines.engagement.is_block_clicked<br>
timelines.engagement.is_dont_like<br>
timelines.engagement.is_mute_clicked<br>
timelines.engagement.is_report_tweet_clicked<br>
</code>
</td>
<td>
<code>
timelines.author.user_state.is_user_heavy_non_tweeter<br>
timelines.author.user_state.is_user_heavy_tweeter<br>
timelines.author.user_state.is_user_light<br>
timelines.author.user_state.is_user_medium_non_tweeter<br>
timelines.author.user_state.is_user_medium_tweeter<br>
timelines.author.user_state.is_user_new<br>
</code>
</td>
<td>
<code>
3.days.count
</code>
</td>
</tr>

<tr>
<td>
<code>
user (profile_real_time_v6)
</code>
</td>
<td>
<code>
profile.engagement.is_clicked<br>
profile.engagement.is_dwelled<br>
profile.engagement.is_favorited<br>
profile.engagement.is_replied<br>
profile.engagement.is_retweeted<br>
</code>
</td>
<td>
<code>
any_feature <br>
client_log_event.tweet.has_consumer_video<br>
client_log_event.tweet.photo_count<br>
</code>
</td>
<td>
<code>
30.minutes.count
</code>
</td>
</tr>

<tr>
<td>
<code>
user (real_time)
</code>
</td>
<td>
<code>
timelines.engagement.is_share_menu_clicked<br>
timelines.engagement.is_shared<br>
</code>
</td>
<td>
<code>
any_feature <br>
client_log_event.tweet.has_consumer_video<br>
client_log_event.tweet.photo_count<br>
</code>
</td>
<td>
<code>
1.days.count<br>
30.minutes.count<br>
</code>
</td>
</tr>

<tr>
<td>
<code>
user (real_time)
</code>
</td>
<td>
<code>
timelines.engagement.is_fullscreen_video_dwelled<br>
timelines.engagement.is_fullscreen_video_dwelled_10_sec<br>
timelines.engagement.is_fullscreen_video_dwelled_20_sec<br>
timelines.engagement.is_fullscreen_video_dwelled_30_sec<br>
timelines.engagement.is_fullscreen_video_dwelled_5_sec<br>
timelines.engagement.is_profile_dwelled<br>
timelines.engagement.is_profile_dwelled_10_sec<br>
timelines.engagement.is_profile_dwelled_20_sec<br>
timelines.engagement.is_profile_dwelled_30_sec<br>
timelines.engagement.is_tweet_detail_dwelled<br>
timelines.engagement.is_tweet_detail_dwelled_15_sec<br>
timelines.engagement.is_tweet_detail_dwelled_25_sec<br>
timelines.engagement.is_tweet_detail_dwelled_30_sec<br>
timelines.engagement.is_tweet_detail_dwelled_8_sec<br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
1.days.count<br>
30.minutes.count<br>
</code>
</td>
</tr>

</table>
</details>
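
The last column of each row pairs a duration with a metric, such as `30.minutes.count` or `50.days.count`. This section does not say how the duration is applied. The sketch below is a minimal illustration that assumes, without confirmation from this document, that the duration acts as an exponential-decay half-life rather than a hard cutoff, and it keeps one such count per key. `DecayedCount` is a hypothetical name and is not part of this repository.

```python
from dataclasses import dataclass


@dataclass
class DecayedCount:
    """Hypothetical counter whose value halves every `half_life_s` seconds.

    Assumes events arrive in non-decreasing timestamp order.
    """

    half_life_s: float
    value: float = 0.0
    last_ts: float = 0.0

    def _decay_to(self, ts: float) -> None:
        # Shrink the stored value by 0.5 ** (elapsed / half_life) before using it.
        if ts > self.last_ts:
            self.value *= 0.5 ** ((ts - self.last_ts) / self.half_life_s)
            self.last_ts = ts

    def add(self, ts: float, amount: float = 1.0) -> None:
        # Record one engagement event (or `amount` of them) at time `ts`.
        self._decay_to(ts)
        self.value += amount

    def read(self, ts: float) -> float:
        # Return the decayed count as of time `ts`.
        self._decay_to(ts)
        return self.value


DAY_S = 24 * 60 * 60

# A "50.days.count"-style counter: two events at t=0 decay to 1.0 after 50 days.
counter = DecayedCount(half_life_s=50 * DAY_S)
counter.add(0.0)
counter.add(0.0)
print(counter.read(50 * DAY_S))  # 1.0
```

With a hard 50-day window, by contrast, both events would count fully until day 50 and then drop out. This section does not say which behavior the aggregation uses.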

<details>
<summary><b><code>user_author_aggregate</code></b></summary>
These features aggregate over user-author pairs.
<br>
<table>
<tr>
<td>
<code>
user_author_v2
</code>
</td>
<td>
<code>
any_label<br>
recap.engagement.is_clicked<br>
recap.engagement.is_favorited<br>
recap.engagement.is_open_linked<br>
recap.engagement.is_photo_expanded<br>
recap.engagement.is_profile_clicked<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
recap.engagement.is_video_playback_50<br>
</code>
</td>
<td>
<code>
engagement_features.in_network.favorites.count<br>
engagement_features.in_network.replies.count<br>
engagement_features.in_network.retweets.count<br>
recap.earlybird.fav_count_v2<br>
recap.earlybird.reply_count_v2<br>
recap.earlybird.retweet_count_v2<br>
recap.searchfeature.blender_score<br>
recap.searchfeature.fav_count<br>
recap.searchfeature.reply_count<br>
recap.searchfeature.retweet_count<br>
recap.searchfeature.text_score<br>
recap.tweetfeature.embeds_impression_count<br>
recap.tweetfeature.embeds_url_count<br>
recap.tweetfeature.has_card<br>
recap.tweetfeature.has_image<br>
recap.tweetfeature.has_link<br>
recap.tweetfeature.has_multiple_media<br>
recap.tweetfeature.has_news<br>
recap.tweetfeature.has_periscope<br>
recap.tweetfeature.has_pro_video<br>
recap.tweetfeature.has_trend<br>
recap.tweetfeature.has_video<br>
recap.tweetfeature.has_vine<br>
recap.tweetfeature.has_visible_link<br>
recap.tweetfeature.is_reply<br>
recap.tweetfeature.is_retweet<br>
recap.tweetfeature.num_mentions<br>
</code>
</td>
<td>
<code>
50.days.count<br>
50.days.sum<br>
</code>
</td>
</tr>
<tr>
<td>
<code>
user_author_v5
</code>
</td>
<td>
<code>
any_label<br>
recap.engagement.is_clicked<br>
recap.engagement.is_favorited<br>
recap.engagement.is_open_linked<br>
recap.engagement.is_photo_expanded<br>
recap.engagement.is_profile_clicked<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
recap.engagement.is_video_playback_50<br>
</code>
</td>
<td>
<code>
any_feature<br>
timelines.earlybird.has_quote<br>
timelines.earlybird.label_abusive_flag<br>
timelines.earlybird.label_abusive_hi_rcl_flag<br>
timelines.earlybird.label_dup_content_flag<br>
timelines.earlybird.label_nsfw_hi_prc_flag<br>
timelines.earlybird.label_nsfw_hi_rcl_flag<br>
timelines.earlybird.label_spam_flag<br>
timelines.earlybird.label_spam_hi_rcl_flag<br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>
<tr>
<td>
<code>
user_author (tweetsource_v1; <br>
these features are sourced from a different underlying dataset)
</code>
</td>
<td>
<code>
any_label<br>
recap.engagement.is_clicked<br>
recap.engagement.is_favorited<br>
recap.engagement.is_open_linked<br>
recap.engagement.is_photo_expanded<br>
recap.engagement.is_profile_clicked<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
recap.engagement.is_video_playback_50<br>
</code>
</td>
<td>
<code>
any_feature<br>
tweetsource.tweet.media.num_tags<br>
tweetsource.tweet.media.video_duration<br>
tweetsource.tweet.text.has_question<br>
tweetsource.tweet.text.length<br>
</code>
</td>
<td>
<code>
50.days.count<br>
50.days.sum<br>
</code>
</td>
</tr>
<tr>
<td>
<code>
user_author (twitter_wide; <br>
these features are sourced from a different underlying dataset)
</code>
</td>
<td>
<code>
recap.engagement.is_favorited<br>
recap.engagement.is_replied<br>
recap.engagement.is_retweeted<br>
</code>
</td>
<td>
<code>
any_feature <br>
recap.tweetfeature.contains_media<br>
recap.tweetfeature.has_card<br>
recap.tweetfeature.has_hashtag<br>
recap.tweetfeature.has_link<br>
recap.tweetfeature.has_mention<br>
recap.tweetfeature.is_reply<br>
timelines.earlybird.has_quote<br>
</code>
</td>
<td>
<code>
50.days.count<br>
</code>
</td>
</tr>
<tr>
<td>
<code>
user_original_author (real_time)
</code>
</td>
<td>
<code>
timelines.engagement.is_shared<br>
</code>
</td>
<td>
<code>
any_feature<br>
</code>
</td>
<td>
<code>
1.days.count<br>
30.minutes.count<br>
</code>
</td>
</tr>

<tr>
<td>
<code>
user_original_author
</code>
</td>
<td>
<code>
recap.engagement.is_replied_reply_favorited_by_author<br>
recap.engagement.is_replied_reply_impressed_by_author<br>
recap.engagement.is_replied_reply_replied_by_author<br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count
</code>
</td>
</tr>
<tr>
<td>
<code>
user_author (real_time, shared)
</code>
</td>
<td>
<code>
timelines.engagement.is_clicked<br>
timelines.engagement.is_dwelled<br>
timelines.engagement.is_favorited<br>
timelines.engagement.is_negative_feedback_union<br>
timelines.engagement.is_photo_expanded<br>
timelines.engagement.is_profile_clicked<br>
timelines.engagement.is_replied<br>
timelines.engagement.is_retweeted<br>
timelines.engagement.is_share_menu_clicked<br>
timelines.engagement.is_video_playback_50
</code>
</td>
<td>
<code>
any_feature
</code>
</td>
<td>
<code>
1.days.count<br>
30.minutes.count
</code>
</td>
</tr>
</table>
</details>
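
Some `user_author` rows report both `50.days.count` and `50.days.sum`. The sketch below is for illustration only. It assumes that `count` tallies label-positive events for a (user, author) pair and that `sum` adds up a continuous feature's value (for example `recap.earlybird.fav_count_v2`) over those same events. The event tuple layout and the function name are hypothetical, and time decay is left out.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, Tuple

# Hypothetical event layout: (user_id, author_id, label_fired, feature_value).
Event = Tuple[int, int, bool, float]
PairKey = Tuple[int, int]


def aggregate_user_author(events: Iterable[Event]) -> Dict[PairKey, Dict[str, float]]:
    """Per (user, author) pair, count label-positive events and sum the feature over them."""
    stats: DefaultDict[PairKey, Dict[str, float]] = defaultdict(
        lambda: {"count": 0.0, "sum": 0.0}
    )
    for user_id, author_id, label_fired, feature_value in events:
        if not label_fired:
            continue  # Only events where the label (e.g. is_favorited) fired contribute.
        pair = stats[(user_id, author_id)]
        pair["count"] += 1.0
        pair["sum"] += feature_value
    return dict(stats)


events = [
    (1, 10, True, 5.0),  # user 1 engaged with a Tweet by author 10 that had 5 favs
    (1, 10, True, 3.0),
    (1, 10, False, 100.0),  # label did not fire: ignored
    (2, 10, True, 1.0),
]
print(aggregate_user_author(events))
# {(1, 10): {'count': 2.0, 'sum': 8.0}, (2, 10): {'count': 1.0, 'sum': 1.0}}
```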



<details>
<summary><b><code>user_engager_aggregate</code></b></summary>
These features aggregate counts of user interaction with other engagers of tweets that the user interacts with.

For example, the <code>user_engager.recap.engagement.is_favorited.any_feature.50.days.count.sparse_top1</code> feature can be parsed as follows: 

For all Tweets that a user Likes, accumulate a running 50-day count of engagement events by every other user who has engaged with those Tweets, where engagement is defined as a Like or a reply. This yields a list of engagement counts, one for each other user who engaged with the Tweets the user Liked, and the largest of these counts is the feature value.
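
The example name above follows the pattern `GROUP.LABEL.FEATURE.DURATION.METRIC[.SPARSE_REDUCTION]`. This section does not state how rows become names. The sketch below assumes that each table row yields one feature per combination of the entries in its label, feature, and aggregate columns. It generates names under that assumption and splits the aggregate suffix back off. `expand_row` and `split_aggregate_suffix` are hypothetical helpers, not functions from this repository.

```python
import re
from itertools import product
from typing import Iterable, List, Optional, Tuple


def expand_row(
    group: str, labels: Iterable[str], features: Iterable[str], aggregates: Iterable[str]
) -> List[str]:
    """Build one feature name per (label, feature, aggregate) combination."""
    return [
        ".".join((group, label, feature, agg))
        for label, feature, agg in product(labels, features, aggregates)
    ]


# Duration, metric, and optional sparse reduction at the end of a name.
_SUFFIX = re.compile(r"\.(\d+\.(?:minutes|hours|days))\.(count|sum)(?:\.(sparse_\w+))?$")


def split_aggregate_suffix(name: str) -> Tuple[str, str, str, Optional[str]]:
    """Split a name into (group.label.feature prefix, duration, metric, sparse reduction)."""
    match = _SUFFIX.search(name)
    if match is None:
        raise ValueError(f"no aggregate suffix in {name!r}")
    return name[: match.start()], match.group(1), match.group(2), match.group(3)


names = expand_row(
    "user_engager",
    ["recap.engagement.is_favorited"],
    ["any_feature"],
    ["50.days.count.sparse_top1", "50.days.count.sparse_top2"],
)
print(names[0])
# user_engager.recap.engagement.is_favorited.any_feature.50.days.count.sparse_top1
print(split_aggregate_suffix(names[0]))
# ('user_engager.recap.engagement.is_favorited.any_feature', '50.days', 'count', 'sparse_top1')
```

The prefix is left joined. Labels and features both contain dots, so the string alone does not tell you where the label ends and the feature begins.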

<br>
<table>
<tr>
<td>
<code>
user_engager <br>
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_clicked <br>
recap.engagement.is_favorited <br>
recap.engagement.is_open_linked <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
recap.engagement.is_replied <br>
recap.engagement.is_retweeted <br>
recap.engagement.is_video_playback_50 <br>
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count.sparse_mean <br>
50.days.count.sparse_nonzero <br>
50.days.count.sparse_sum <br>
50.days.count.sparse_top1 <br>
50.days.count.sparse_top2 <br>
</code>
</td>
</tr>
</table>
</details>
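
The `user_engager` walkthrough above can be sketched in Python. The source defines only `sparse_top1`, which is the top count. The readings of the other four reductions used below are assumptions inferred from their names:

- `sparse_mean`, `sparse_nonzero`, and `sparse_sum` are the mean, number, and sum of the non-zero counts.
- `sparse_top2` is the second-largest count.

The 50-day window is left out, and the function names and input layout are hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable, List, Mapping


def engager_counts(
    user_id: int, liked_tweets: Iterable[str], engagers_by_tweet: Mapping[str, List[int]]
) -> Dict[int, float]:
    """Count engagement events (Likes or replies) by other users on the Tweets `user_id` Liked."""
    counts: Counter = Counter()
    for tweet_id in liked_tweets:
        # Each list entry is one engagement event; the same engager may appear more than once.
        for engager_id in engagers_by_tweet.get(tweet_id, []):
            if engager_id != user_id:
                counts[engager_id] += 1
    return {engager: float(n) for engager, n in counts.items()}


def sparse_reductions(counts: Mapping[int, float]) -> Dict[str, float]:
    """Collapse per-engager counts into scalar features (only sparse_top1 is defined by the source)."""
    nonzero = sorted((c for c in counts.values() if c != 0), reverse=True)
    total = float(sum(nonzero))
    return {
        "sparse_mean": total / len(nonzero) if nonzero else 0.0,
        "sparse_nonzero": float(len(nonzero)),
        "sparse_sum": total,
        "sparse_top1": nonzero[0] if nonzero else 0.0,
        "sparse_top2": nonzero[1] if len(nonzero) > 1 else 0.0,
    }


# User 1 Liked t1 and t2; user 2 both Liked and replied to t2, so appears twice there.
counts = engager_counts(1, ["t1", "t2"], {"t1": [1, 2, 3], "t2": [2, 2, 4]})
print(counts)  # {2: 3.0, 3: 1.0, 4: 1.0}
print(sparse_reductions(counts)["sparse_top1"])  # 3.0
```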


<details>
<summary><b><code>user_inferred_topic_aggregate</code></b></summary>
These features aggregate short-term and long-term engagement between a user and Tweets assigned to an internally predicted (inferred) topic, whether or not the Tweet is actually tagged with that topic.
<br>
<table>
<tr>
<td>
<code>
user_inferred_topic_v1
</code>
</td>
<td>
<code>
any_label <br>
recap.engagement.is_clicked <br>
recap.engagement.is_favorited <br>
recap.engagement.is_open_linked <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
recap.engagement.is_replied <br>
recap.engagement.is_retweeted <br>
recap.engagement.is_video_playback_50
</code>
</td>
<td>
<code>
any_feature <br>
</code>
</td>
<td>
<code>
50.days.count.sparse_mean <br>
50.days.count.sparse_nonzero <br>
50.days.count.sparse_sum <br>
50.days.count.sparse_top1 <br>
50.days.count.sparse_top2 <br>
</code>
</td>
</tr>
<tr>
<td>
<code>
user_inferred_topic_v2
</code>
</td>
<td>
<code>
recap.engagement.is_clicked <br>
recap.engagement.is_favorited <br>
recap.engagement.is_open_linked <br>
recap.engagement.is_photo_expanded <br>
recap.engagement.is_profile_clicked <br>
recap.engagement.is_replied <br>
recap.engagement.is_retweeted <br>
recap.engagement.is_video_playback_50 <br>
</code>
</td>
<td>
<code>
engagement_features.in_network.favorites.count <br>
engagement_features.in_network.retweets.count <br>
recap.searchfeature.fav_count <br>
recap.tweetfeature.contains_media <br>
recap.tweetfeature.has_card <br>
recap.tweetfeature.has_image <br>
recap.tweetfeature.has_link <br>
recap.tweetfeature.has_news <br>
recap.tweetfeature.has_trend <br>
recap.tweetfeature.has_video <br>
recap.tweetfeature.is_reply <br>
recap.tweetfeature.is_retweet <br>
recap.tweetfeature.is_sensitive <br>
recap.tweetfeature.match_searcher_langs <br>
recap.tweetfeature.match_searcher_main_lang <br>
recap.tweetfeature.match_ui_lang <br>
recap.tweetfeature.mention_searcher <br>
recap.tweetfeature.reply_other <br>
recap.tweetfeature.reply_searcher <br>
recap.tweetfeature.retweet_other <br>
recap.tweetfeature.retweet_searcher <br>
tweetsource.tweet.media.aspect_ratio_den <br>
tweetsource.tweet.text.num_caps <br>
tweetsource.tweet.text.num_newlines <br>
tweetsource.v2.tweet.media.has_description <br>
SYMBOL INDEX (456 symbols across 77 files)

FILE: common/batch.py
  class BatchBase (line 14) | class BatchBase(Pipelineable, abc.ABC):
    method as_dict (line 16) | def as_dict(self) -> Dict:
    method to (line 19) | def to(self, device: torch.device, non_blocking: bool = False):
    method record_stream (line 25) | def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    method pin_memory (line 29) | def pin_memory(self):
    method __repr__ (line 35) | def __repr__(self) -> str:
    method batch_size (line 42) | def batch_size(self) -> int:
  class DataclassBatch (line 53) | class DataclassBatch(BatchBase):
    method feature_names (line 55) | def feature_names(cls):
    method as_dict (line 58) | def as_dict(self):
    method from_schema (line 66) | def from_schema(name: str, schema):
    method from_fields (line 75) | def from_fields(name: str, fields: dict):
  class DictionaryBatch (line 83) | class DictionaryBatch(BatchBase, dict):
    method as_dict (line 84) | def as_dict(self) -> Dict:

FILE: common/checkpointing/snapshot.py
  class Snapshot (line 15) | class Snapshot:
    method __init__ (line 22) | def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    method step (line 28) | def step(self):
    method step (line 32) | def step(self, step: int) -> None:
    method walltime (line 36) | def walltime(self):
    method walltime (line 40) | def walltime(self, walltime: float) -> None:
    method save (line 43) | def save(self, global_step: int) -> "PendingSnapshot":
    method restore (line 60) | def restore(self, checkpoint: str) -> None:
    method get_torch_snapshot (line 80) | def get_torch_snapshot(
    method load_snapshot_to_weight (line 97) | def load_snapshot_to_weight(
  function _eval_subdir (line 123) | def _eval_subdir(checkpoint_path: str) -> str:
  function _eval_done_path (line 127) | def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
  function is_done_eval (line 131) | def is_done_eval(checkpoint_path: str, eval_partition: str):
  function mark_done_eval (line 135) | def mark_done_eval(checkpoint_path: str, eval_partition: str):
  function step_from_checkpoint (line 139) | def step_from_checkpoint(checkpoint: str) -> int:
  function checkpoints_iterator (line 143) | def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, time...
  function get_checkpoint (line 177) | def get_checkpoint(
  function get_checkpoints (line 211) | def get_checkpoints(save_dir: str) -> List[str]:
  function wait_for_evaluators (line 229) | def wait_for_evaluators(

FILE: common/device.py
  function maybe_setup_tensorflow (line 7) | def maybe_setup_tensorflow():
  function setup_and_get_device (line 16) | def setup_and_get_device(tf_ok: bool = True) -> torch.device:

FILE: common/filesystem/test_infer_fs.py
  function test_infer_fs (line 8) | def test_infer_fs():

FILE: common/filesystem/util.py
  function infer_fs (line 10) | def infer_fs(path: str):
  function is_local_fs (line 20) | def is_local_fs(fs):
  function is_gcs_fs (line 24) | def is_gcs_fs(fs):

FILE: common/log_weights.py
  function weights_to_log (line 11) | def weights_to_log(
  function log_ebc_norms (line 47) | def log_ebc_norms(

FILE: common/modules/embedding/config.py
  class DataType (line 10) | class DataType(str, Enum):
  class EmbeddingSnapshot (line 15) | class EmbeddingSnapshot(base_config.BaseConfig):
  class EmbeddingBagConfig (line 26) | class EmbeddingBagConfig(base_config.BaseConfig):
  class LargeEmbeddingsConfig (line 42) | class LargeEmbeddingsConfig(base_config.BaseConfig):
  class Mode (line 54) | class Mode(str, Enum):

FILE: common/modules/embedding/embedding.py
  class LargeEmbeddings (line 13) | class LargeEmbeddings(nn.Module):
    method __init__ (line 14) | def __init__(
    method forward (line 51) | def forward(

FILE: common/run_training.py
  function is_distributed_worker (line 13) | def is_distributed_worker():
  function maybe_run_training (line 19) | def maybe_run_training(

FILE: common/test_device.py
  function test_device (line 10) | def test_device():

FILE: common/testing_utils.py
  function mock_pg (line 21) | def mock_pg():

FILE: common/utils.py
  function _read_file (line 14) | def _read_file(f):
  function setup_configuration (line 19) | def setup_configuration(

FILE: common/wandb.py
  class WandbConfig (line 8) | class WandbConfig(base_config.BaseConfig):

FILE: core/config/base_config.py
  class BaseConfig (line 10) | class BaseConfig(pydantic.BaseModel):
    class Config (line 30) | class Config:
    method _field_data_map (line 37) | def _field_data_map(cls, field_data_name):
    method _one_of_check (line 47) | def _one_of_check(cls, values):
    method _at_most_one_of_check (line 56) | def _at_most_one_of_check(cls, values):
    method pretty_print (line 64) | def pretty_print(self) -> str:

FILE: core/config/base_config_test.py
  class BaseConfigTest (line 8) | class BaseConfigTest(TestCase):
    method test_extra_forbidden (line 9) | def test_extra_forbidden(self):
    method test_one_of (line 17) | def test_one_of(self):
    method test_at_most_one_of (line 29) | def test_at_most_one_of(self):

FILE: core/config/config_load.py
  function load_config_from_yaml (line 10) | def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):

FILE: core/config/test_config_load.py
  class _PointlessConfig (line 10) | class _PointlessConfig(BaseConfig):
  function test_load_config_from_yaml (line 15) | def test_load_config_from_yaml(tmp_path):

FILE: core/config/training.py
  class RuntimeConfig (line 11) | class RuntimeConfig(base_config.BaseConfig):
  class TrainingConfig (line 19) | class TrainingConfig(base_config.BaseConfig):

FILE: core/custom_training_loop.py
  function get_new_iterator (line 29) | def get_new_iterator(iterable: Iterable):
  function _get_step_fn (line 48) | def _get_step_fn(pipeline, data_iterator, training: bool):
  function _run_evaluation (line 64) | def _run_evaluation(
  function train (line 92) | def train(
  function log_eval_results (line 259) | def log_eval_results(
  function only_evaluate (line 274) | def only_evaluate(

FILE: core/debug_training_loop.py
  function train (line 20) | def train(

FILE: core/loss_type.py
  class LossType (line 5) | class LossType(str, Enum):

FILE: core/losses.py
  function _maybe_warn (line 11) | def _maybe_warn(reduction: str):
  function build_loss (line 23) | def build_loss(
  function get_global_loss_detached (line 36) | def get_global_loss_detached(local_loss, reduction="mean"):
  function build_multi_task_loss (line 62) | def build_multi_task_loss(

FILE: core/metric_mixin.py
  class MetricMixin (line 36) | class MetricMixin:
    method transform (line 38) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
    method update (line 41) | def update(self, outputs: Dict[str, torch.Tensor]):
  class TaskMixin (line 50) | class TaskMixin:
    method __init__ (line 51) | def __init__(self, task_idx: int = -1, **kwargs):
  class StratifyMixin (line 56) | class StratifyMixin:
    method __init__ (line 57) | def __init__(
    method maybe_apply_stratification (line 65) | def maybe_apply_stratification(
  function prepend_transform (line 86) | def prepend_transform(base_metric: torchmetrics.Metric, transform: Calla...

FILE: core/metrics.py
  function probs_and_labels (line 14) | def probs_and_labels(
  class Count (line 29) | class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
    method transform (line 30) | def transform(self, outputs):
  class Ctr (line 38) | class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method transform (line 39) | def transform(self, outputs):
  class Pctr (line 47) | class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method transform (line 48) | def transform(self, outputs):
  class Precision (line 56) | class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
    method transform (line 57) | def transform(self, outputs):
  class Recall (line 62) | class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
    method transform (line 63) | def transform(self, outputs):
  class TorchMetricsRocauc (line 68) | class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
    method transform (line 69) | def transform(self, outputs):
  class Auc (line 74) | class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method __init__ (line 80) | def __init__(self, num_samples, **kwargs):
    method transform (line 84) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
  class PosRanks (line 95) | class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method __init__ (line 102) | def __init__(self, **kwargs):
    method transform (line 105) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
  class ReciprocalRank (line 113) | class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method __init__ (line 120) | def __init__(self, **kwargs):
    method transform (line 123) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
  class HitAtK (line 131) | class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    method __init__ (line 139) | def __init__(self, k: int, **kwargs):
    method transform (line 143) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:

FILE: core/test_metrics.py
  class MockStratifierConfig (line 11) | class MockStratifierConfig:
  class Count (line 17) | class Count(MetricMixin, SumMetric):
    method transform (line 18) | def transform(self, outputs):
  function test_count_metric (line 25) | def test_count_metric():
  function test_collections (line 38) | def test_collections():
  function test_task_dependent_ctr (line 53) | def test_task_dependent_ctr():
  function test_stratified_ctr (line 71) | def test_stratified_ctr():
  function test_auc (line 116) | def test_auc():
  function test_pos_rank (line 133) | def test_pos_rank():
  function test_reciprocal_rank (line 149) | def test_reciprocal_rank():
  function test_hit_k (line 165) | def test_hit_k():

FILE: core/test_train_pipeline.py
  class MockDataclassBatch (line 13) | class MockDataclassBatch(DataclassBatch):
  class MockModule (line 18) | class MockModule(torch.nn.Module):
    method __init__ (line 19) | def __init__(self) -> None:
    method forward (line 24) | def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, to...
  function create_batch (line 30) | def create_batch(bsz: int):
  function test_sparse_pipeline (line 37) | def test_sparse_pipeline():
  function test_amp (line 67) | def test_amp():

FILE: core/train_pipeline.py
  class TrainPipeline (line 41) | class TrainPipeline(abc.ABC, Generic[In, Out]):
    method progress (line 43) | def progress(self, dataloader_iter: Iterator[In]) -> Out:
  function _to_device (line 47) | def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
  function _wait_for_batch (line 54) | def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Strea...
  class TrainPipelineBase (line 73) | class TrainPipelineBase(TrainPipeline[In, Out]):
    method __init__ (line 81) | def __init__(
    method _connect (line 96) | def _connect(self, dataloader_iter: Iterator[In]) -> None:
    method progress (line 103) | def progress(self, dataloader_iter: Iterator[In]) -> Out:
  class Tracer (line 141) | class Tracer(torch.fx.Tracer):
    method __init__ (line 148) | def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
    method is_leaf_module (line 152) | def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: st...
  class TrainPipelineContext (line 159) | class TrainPipelineContext:
  class ArgInfo (line 168) | class ArgInfo:
  class PipelinedForward (line 179) | class PipelinedForward:
    method __init__ (line 180) | def __init__(
    method __call__ (line 195) | def __call__(self, *input, **kwargs) -> Awaitable:
    method name (line 232) | def name(self) -> str:
    method args (line 236) | def args(self) -> List[ArgInfo]:
  function _start_data_dist (line 240) | def _start_data_dist(
  function _get_node_args_helper (line 282) | def _get_node_args_helper(
  function _get_node_args (line 332) | def _get_node_args(
  function _get_unsharded_module_names_helper (line 349) | def _get_unsharded_module_names_helper(
  function _get_unsharded_module_names (line 376) | def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
  function _rewrite_model (line 390) | def _rewrite_model(  # noqa C901
  class TrainPipelineSparseDist (line 442) | class TrainPipelineSparseDist(TrainPipeline[In, Out]):
    method __init__ (line 462) | def __init__(
    method _connect (line 506) | def _connect(self, dataloader_iter: Iterator[In]) -> None:
    method progress (line 525) | def progress(self, dataloader_iter: Iterator[In]) -> Out:
    method _sync_pipeline (line 618) | def _sync_pipeline(self) -> None:

FILE: machines/environment.py
  function on_kf (line 11) | def on_kf():
  function has_readers (line 15) | def has_readers():
  function get_task_type (line 22) | def get_task_type():
  function is_chief (line 28) | def is_chief() -> bool:
  function is_reader (line 32) | def is_reader() -> bool:
  function is_dispatcher (line 36) | def is_dispatcher() -> bool:
  function get_task_index (line 40) | def get_task_index():
  function get_reader_port (line 48) | def get_reader_port():
  function get_dds (line 54) | def get_dds():
  function get_dds_dispatcher_address (line 64) | def get_dds_dispatcher_address():
  function get_dds_worker_address (line 75) | def get_dds_worker_address():
  function get_num_readers (line 87) | def get_num_readers():
  function get_flight_server_addresses (line 96) | def get_flight_server_addresses():
  function get_dds_journaling_dir (line 107) | def get_dds_journaling_dir():

FILE: machines/get_env.py
  function main (line 10) | def main(argv):

FILE: machines/is_venv.py
  function is_venv (line 11) | def is_venv():
  function _main (line 16) | def _main():

FILE: machines/list_ops.py
  function main (line 32) | def main(argv):

FILE: metrics/aggregation.py
  function update_mean (line 10) | def update_mean(
  function stable_mean_dist_reduce_fn (line 38) | def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
  class StableMean (line 56) | class StableMean(torchmetrics.Metric):
    method __init__ (line 65) | def __init__(self, **kwargs):
    method update (line 77) | def update(self, value: torch.Tensor, weight: Union[float, torch.Tenso...
    method compute (line 93) | def compute(self) -> torch.Tensor:

FILE: metrics/auroc.py
  function _compute_helper (line 13) | def _compute_helper(
  class AUROCWithMWU (line 53) | class AUROCWithMWU(torchmetrics.Metric):
    method __init__ (line 63) | def __init__(self, label_threshold: float = 0.5, raise_missing_class: ...
    method update (line 81) | def update(
    method compute (line 101) | def compute(self) -> torch.Tensor:

FILE: metrics/rce.py
  function _smooth (line 14) | def _smooth(
  function _binary_cross_entropy_with_clipping (line 27) | def _binary_cross_entropy_with_clipping(
  class RCE (line 54) | class RCE(torchmetrics.Metric):
    method __init__ (line 125) | def __init__(
    method update (line 152) | def update(
    method compute (line 169) | def compute(self) -> torch.Tensor:
    method reset (line 183) | def reset(self):
    method forward (line 191) | def forward(self, *args, **kwargs):
  class NRCE (line 205) | class NRCE(RCE):
    method __init__ (line 226) | def __init__(
    method update (line 246) | def update(
    method reset (line 275) | def reset(self):

FILE: ml_logging/absl_logging.py
  function setup_absl_logging (line 16) | def setup_absl_logging():

FILE: ml_logging/test_torch_logging.py
  class Testtlogging (line 6) | class Testtlogging(unittest.TestCase):
    method test_warn_once (line 7) | def test_warn_once(self):

FILE: ml_logging/torch_logging.py
  function rank_specific (line 20) | def rank_specific(logger):

FILE: model.py
  class ModelAndLoss (line 12) | class ModelAndLoss(torch.nn.Module):
    method __init__ (line 15) | def __init__(
    method forward (line 29) | def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
  function maybe_shard_model (line 53) | def maybe_shard_model(
  function log_sharded_tensor_content (line 76) | def log_sharded_tensor_content(weight_name: str, table_name: str, weight...

FILE: optimizers/config.py
  class PiecewiseConstant (line 10) | class PiecewiseConstant(base_config.BaseConfig):
  class LinearRampToConstant (line 15) | class LinearRampToConstant(base_config.BaseConfig):
  class LinearRampToCosine (line 22) | class LinearRampToCosine(base_config.BaseConfig):
  class LearningRate (line 33) | class LearningRate(base_config.BaseConfig):
  class OptimizerAlgorithmConfig (line 40) | class OptimizerAlgorithmConfig(base_config.BaseConfig):
  class AdamConfig (line 47) | class AdamConfig(OptimizerAlgorithmConfig):
  class SgdConfig (line 54) | class SgdConfig(OptimizerAlgorithmConfig):
  class AdagradConfig (line 59) | class AdagradConfig(OptimizerAlgorithmConfig):
  class OptimizerConfig (line 64) | class OptimizerConfig(base_config.BaseConfig):
  function get_optimizer_algorithm_config (line 74) | def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):

FILE: optimizers/optimizer.py
  function compute_lr (line 16) | def compute_lr(lr_config, step):
  class LRShim (line 48) | class LRShim(_LRScheduler):
    method __init__ (line 55) | def __init__(
    method get_lr (line 74) | def get_lr(self):
    method _get_closed_form_lr (line 82) | def _get_closed_form_lr(self):
  function get_optimizer_class (line 86) | def get_optimizer_class(optimizer_config: OptimizerConfig):
  function build_optimizer (line 95) | def build_optimizer(

FILE: projects/home/recap/config.py
  class TrainingConfig (line 11) | class TrainingConfig(config_mod.BaseConfig):
  class RecapConfig (line 34) | class RecapConfig(config_mod.BaseConfig):
  class JobMode (line 49) | class JobMode(str, Enum):

FILE: projects/home/recap/data/config.py
  class ExplicitDateInputs (line 10) | class ExplicitDateInputs(base_config.BaseConfig):
  class ExplicitDatetimeInputs (line 21) | class ExplicitDatetimeInputs(base_config.BaseConfig):
  class DdsCompressionOption (line 32) | class DdsCompressionOption(str, Enum):
  class DatasetConfig (line 38) | class DatasetConfig(base_config.BaseConfig):
  class TruncateAndSlice (line 82) | class TruncateAndSlice(base_config.BaseConfig):
  class DataType (line 99) | class DataType(str, Enum):
  class DownCast (line 109) | class DownCast(base_config.BaseConfig):
  class TaskData (line 116) | class TaskData(base_config.BaseConfig):
  class SegDenseSchema (line 127) | class SegDenseSchema(base_config.BaseConfig):
  class RectifyLabels (line 142) | class RectifyLabels(base_config.BaseConfig):
  class ExtractFeaturesRow (line 157) | class ExtractFeaturesRow(base_config.BaseConfig):
  class ExtractFeatures (line 172) | class ExtractFeatures(base_config.BaseConfig):
  class DownsampleNegatives (line 179) | class DownsampleNegatives(base_config.BaseConfig):
  class Preprocess (line 194) | class Preprocess(base_config.BaseConfig):
  class Sampler (line 208) | class Sampler(base_config.BaseConfig):
  class RecapDataConfig (line 221) | class RecapDataConfig(DatasetConfig):
    method _validate_evaluation_tasks (line 241) | def _validate_evaluation_tasks(cls, values):

FILE: projects/home/recap/data/dataset.py
  class RecapBatch (line 22) | class RecapBatch(DataclassBatch):
    method __post_init__ (line 35) | def __post_init__(self):
  function to_batch (line 43) | def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> Rec...
  function _chain (line 91) | def _chain(param, f1, f2):
  function _add_weights (line 103) | def _add_weights(inputs, tasks: Dict[str, TaskData]):
  function get_datetimes (line 130) | def get_datetimes(explicit_datetime_inputs):
  function get_explicit_datetime_inputs_files (line 143) | def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
  function _map_output_for_inference (line 183) | def _map_output_for_inference(
  function _map_output_for_train_eval (line 198) | def _map_output_for_train_eval(
  function _add_weights_based_on_sampling_rates (line 216) | def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskDa...
  class RecapDataset (line 242) | class RecapDataset(torch.utils.data.IterableDataset):
    method __init__ (line 243) | def __init__(
    method _init_tensor_spec (line 304) | def _init_tensor_spec(self):
    method _create_tf_dataset (line 315) | def _create_tf_dataset(self):
    method _create_base_tf_dataset (line 373) | def _create_base_tf_dataset(self, batch_size: int):
    method _gen (line 469) | def _gen(self):
    method to_dataloader (line 473) | def to_dataloader(self) -> Dict[str, torch.Tensor]:
    method __iter__ (line 476) | def __iter__(self):

FILE: projects/home/recap/data/generate_random_data.py
  function _generate_random_example (line 17) | def _generate_random_example(
  function _float_feature (line 35) | def _float_feature(value):
  function _int64_feature (line 39) | def _int64_feature(value):
  function _serialize_example (line 43) | def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
  function generate_data (line 53) | def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
  function _generate_data_main (line 70) | def _generate_data_main(unused_argv):

FILE: projects/home/recap/data/preprocessors.py
  class TruncateAndSlice (line 11) | class TruncateAndSlice(tf.keras.Model):
    method __init__ (line 14) | def __init__(self, truncate_and_slice_config):
    method call (line 34) | def call(self, inputs, training=None, mask=None):
  class DownCast (line 53) | class DownCast(tf.keras.Model):
    method __init__ (line 59) | def __init__(self, downcast_config):
    method call (line 67) | def call(self, inputs, training=None, mask=None):
  class RectifyLabels (line 80) | class RectifyLabels(tf.keras.Model):
    method __init__ (line 83) | def __init__(self, rectify_label_config):
    method call (line 88) | def call(self, inputs, training=None, mask=None):
  class ExtractFeatures (line 104) | class ExtractFeatures(tf.keras.Model):
    method __init__ (line 107) | def __init__(self, extract_features_config):
    method call (line 111) | def call(self, inputs, training=None, mask=None):
  class DownsampleNegatives (line 119) | class DownsampleNegatives(tf.keras.Model):
    method __init__ (line 130) | def __init__(self, downsample_negatives_config):
    method call (line 134) | def call(self, inputs, training=None, mask=None):
  function build_preprocess (line 170) | def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):

FILE: projects/home/recap/data/tfe_parsing.py
  function create_tf_example_schema (line 14) | def create_tf_example_schema(
  function make_mantissa_mask (line 60) | def make_mantissa_mask(mask_length: int) -> tf.Tensor:
  function mask_mantissa (line 65) | def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
  function parse_tf_example (line 71) | def parse_tf_example(
  function get_seg_dense_parse_fn (line 108) | def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):

FILE: projects/home/recap/data/util.py
  function keyed_tensor_from_tensors_dict (line 8) | def keyed_tensor_from_tensors_dict(
  function _compute_jagged_tensor_from_tensor (line 30) | def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[to...
  function jagged_tensor_from_tensor (line 41) | def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedT...
  function keyed_jagged_tensor_from_tensors_dict (line 55) | def keyed_jagged_tensor_from_tensors_dict(
  function _tf_to_numpy (line 93) | def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
  function _dense_tf_to_torch (line 97) | def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Ten...
  function sparse_or_dense_tf_to_torch (line 109) | def sparse_or_dense_tf_to_torch(

FILE: projects/home/recap/embedding/config.py
  class EmbeddingSnapshot (line 8) | class EmbeddingSnapshot(base_config.BaseConfig):
  class EmbeddingBagConfig (line 20) | class EmbeddingBagConfig(base_config.BaseConfig):
  class EmbeddingOptimizerConfig (line 32) | class EmbeddingOptimizerConfig(base_config.BaseConfig):
  class LargeEmbeddingsConfig (line 41) | class LargeEmbeddingsConfig(base_config.BaseConfig):
  class StratifierConfig (line 54) | class StratifierConfig(base_config.BaseConfig):
  class SmallEmbeddingBagConfig (line 60) | class SmallEmbeddingBagConfig(base_config.BaseConfig):
  class SmallEmbeddingBagConfig (line 69) | class SmallEmbeddingBagConfig(base_config.BaseConfig):
  class SmallEmbeddingsConfig (line 78) | class SmallEmbeddingsConfig(base_config.BaseConfig):

FILE: projects/home/recap/main.py
  function run (line 36) | def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):

FILE: projects/home/recap/model/config.py
  class DropoutConfig (line 12) | class DropoutConfig(base_config.BaseConfig):
  class LayerNormConfig (line 20) | class LayerNormConfig(base_config.BaseConfig):
  class BatchNormConfig (line 31) | class BatchNormConfig(base_config.BaseConfig):
  class DenseLayerConfig (line 42) | class DenseLayerConfig(base_config.BaseConfig):
  class MlpConfig (line 47) | class MlpConfig(base_config.BaseConfig):
  class BatchNormConfig (line 54) | class BatchNormConfig(base_config.BaseConfig):
  class DoubleNormLogConfig (line 63) | class DoubleNormLogConfig(base_config.BaseConfig):
  class Log1pAbsConfig (line 71) | class Log1pAbsConfig(base_config.BaseConfig):
  class ClipLog1pAbsConfig (line 75) | class ClipLog1pAbsConfig(base_config.BaseConfig):
  class ZScoreLogConfig (line 81) | class ZScoreLogConfig(base_config.BaseConfig):
  class FeaturizationConfig (line 101) | class FeaturizationConfig(base_config.BaseConfig):
  class DropoutConfig (line 113) | class DropoutConfig(base_config.BaseConfig):
  class MlpConfig (line 121) | class MlpConfig(base_config.BaseConfig):
  class DcnConfig (line 134) | class DcnConfig(base_config.BaseConfig):
  class MaskBlockConfig (line 150) | class MaskBlockConfig(base_config.BaseConfig):
  class MaskNetConfig (line 161) | class MaskNetConfig(base_config.BaseConfig):
  class PositionDebiasConfig (line 167) | class PositionDebiasConfig(base_config.BaseConfig):
  class AffineMap (line 185) | class AffineMap(base_config.BaseConfig):
  class DLRMConfig (line 192) | class DLRMConfig(base_config.BaseConfig):
  class TaskModel (line 200) | class TaskModel(base_config.BaseConfig):
  class MultiTaskType (line 215) | class MultiTaskType(str, enum.Enum):
  class ModelConfig (line 221) | class ModelConfig(base_config.BaseConfig):
    method _validate_mtl (line 249) | def _validate_mtl(cls, values):

FILE: projects/home/recap/model/entrypoint.py
  function sanitize (line 20) | def sanitize(task_name):
  function unsanitize (line 24) | def unsanitize(sanitized_task_name):
  function _build_single_task_model (line 28) | def _build_single_task_model(task: model_config_mod.TaskModel, input_sha...
  class MultiTaskRankingModel (line 40) | class MultiTaskRankingModel(torch.nn.Module):
    method __init__ (line 43) | def __init__(
    method forward (line 159) | def forward(
  function create_ranking_model (line 264) | def create_ranking_model(

FILE: projects/home/recap/model/feature_transform.py
  function log_transform (line 13) | def log_transform(x: torch.Tensor) -> torch.Tensor:
  class BatchNorm (line 18) | class BatchNorm(torch.nn.Module):
    method __init__ (line 19) | def __init__(self, num_features: int, config: BatchNormConfig):
    method forward (line 23) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class LayerNorm (line 27) | class LayerNorm(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, normalized_shape: Union[int, Sequence[int]], config...
    method forward (line 40) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Log1pAbs (line 44) | class Log1pAbs(torch.nn.Module):
    method __init__ (line 45) | def __init__(self):
    method forward (line 48) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class InputNonFinite (line 52) | class InputNonFinite(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, fill_value: float = 0):
    method forward (line 60) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Clamp (line 64) | class Clamp(torch.nn.Module):
    method __init__ (line 65) | def __init__(self, min_value: float, max_value: float):
    method forward (line 76) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class DoubleNormLog (line 80) | class DoubleNormLog(torch.nn.Module):
    method __init__ (line 83) | def __init__(
    method forward (line 108) | def forward(
  function build_features_preprocessor (line 118) | def build_features_preprocessor(

FILE: projects/home/recap/model/mask_net.py
  function _init_weights (line 8) | def _init_weights(module):
  class MaskBlock (line 14) | class MaskBlock(torch.nn.Module):
    method __init__ (line 15) | def __init__(
    method forward (line 44) | def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
  class MaskNet (line 51) | class MaskNet(torch.nn.Module):
    method __init__ (line 52) | def __init__(self, mask_net_config: config.MaskNetConfig, in_features:...
    method forward (line 79) | def forward(self, inputs: torch.Tensor):

FILE: projects/home/recap/model/mlp.py
  function _init_weights (line 9) | def _init_weights(module):
  class Mlp (line 15) | class Mlp(torch.nn.Module):
    method __init__ (line 16) | def __init__(self, in_features: int, mlp_config: MlpConfig):
    method forward (line 44) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method shared_size (line 53) | def shared_size(self):
    method out_features (line 57) | def out_features(self):

FILE: projects/home/recap/model/model_and_loss.py
  class ModelAndLoss (line 7) | class ModelAndLoss(torch.nn.Module):
    method __init__ (line 8) | def __init__(
    method forward (line 25) | def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]

FILE: projects/home/recap/model/numeric_calibration.py
  class NumericCalibration (line 4) | class NumericCalibration(torch.nn.Module):
    method __init__ (line 5) | def __init__(
    method forward (line 18) | def forward(self, probs: torch.Tensor):

FILE: projects/home/recap/optimizer/config.py
  class RecapAdamConfig (line 11) | class RecapAdamConfig(base_config.BaseConfig):
  class MultiTaskLearningRates (line 17) | class MultiTaskLearningRates(base_config.BaseConfig):
  class RecapOptimizerConfig (line 27) | class RecapOptimizerConfig(base_config.BaseConfig):

FILE: projects/home/recap/optimizer/optimizer.py
  class RecapLRShim (line 25) | class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
    method __init__ (line 33) | def __init__(
    method get_lr (line 59) | def get_lr(self):
    method _get_closed_form_lr (line 67) | def _get_closed_form_lr(self):
  function build_optimizer (line 78) | def build_optimizer(

FILE: projects/twhin/config.py
  class TwhinConfig (line 9) | class TwhinConfig(base_config.BaseConfig):

FILE: projects/twhin/data/config.py
  class TwhinDataConfig (line 6) | class TwhinDataConfig(base_config.BaseConfig):

FILE: projects/twhin/data/data.py
  function create_dataset (line 6) | def create_dataset(data_config: TwhinDataConfig, model_config: TwhinMode...

FILE: projects/twhin/data/edges.py
  class EdgeBatch (line 17) | class EdgeBatch(DataclassBatch):
  class EdgesDataset (line 24) | class EdgesDataset(Dataset):
    method __init__ (line 27) | def __init__(
    method pa_to_batch (line 58) | def pa_to_batch(self, batch: pa.RecordBatch):
    method _to_kjt (line 72) | def _to_kjt(
    method to_batches (line 149) | def to_batches(self):

FILE: projects/twhin/data/test_data.py
  function test_create_dataset (line 5) | def test_create_dataset():

FILE: projects/twhin/data/test_edges.py
  function test_gen (line 25) | def test_gen():

FILE: projects/twhin/metrics.py
  function create_metrics (line 7) | def create_metrics(

FILE: projects/twhin/models/config.py
  class TwhinEmbeddingsConfig (line 12) | class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
    method embedding_dims_match (line 14) | def embedding_dims_match(cls, tables):
  class Operator (line 23) | class Operator(str, enum.Enum):
  class Relation (line 27) | class Relation(pydantic.BaseModel):
  class TwhinModelConfig (line 44) | class TwhinModelConfig(base_config.BaseConfig):
    method valid_node_types (line 50) | def valid_node_types(cls, relation, values, **kwargs):

FILE: projects/twhin/models/models.py
  class TwhinModel (line 16) | class TwhinModel(nn.Module):
    method __init__ (line 17) | def __init__(self, model_config: TwhinModelConfig, data_config: TwhinD...
    method forward (line 33) | def forward(self, batch: EdgeBatch):
  function apply_optimizers (line 100) | def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
  class TwhinModelAndLoss (line 118) | class TwhinModelAndLoss(torch.nn.Module):
    method __init__ (line 119) | def __init__(
    method forward (line 138) | def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]

FILE: projects/twhin/models/test_models.py
  function twhin_model_config (line 20) | def twhin_model_config() -> TwhinModelConfig:
  function twhin_data_config (line 54) | def twhin_data_config() -> TwhinDataConfig:
  function test_twhin_model (line 67) | def test_twhin_model():
  function test_unequal_dims (line 86) | def test_unequal_dims():

FILE: projects/twhin/optimizer.py
  function _lr_from_config (line 17) | def _lr_from_config(optimizer_config):
  function build_optimizer (line 26) | def build_optimizer(model: TwhinModel, config: TwhinModelConfig):

FILE: projects/twhin/run.py
  function run (line 36) | def run(
  function main (line 82) | def main(argv):

FILE: projects/twhin/test_optimizer.py
  function test_twhin_optimizer (line 15) | def test_twhin_optimizer():

FILE: reader/dataset.py
  class _Reader (line 27) | class _Reader(pa.flight.FlightServerBase):
    method __init__ (line 30) | def __init__(self, location: str, ds: "Dataset"):
    method do_get (line 35) | def do_get(self, _, __):
  class Dataset (line 48) | class Dataset(torch.utils.data.IterableDataset):
    method __init__ (line 51) | def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
    method _validate_columns (line 66) | def _validate_columns(self):
    method serve (line 72) | def serve(self):
    method _create_dataset (line 76) | def _create_dataset(self):
    method to_batches (line 84) | def to_batches(self):
    method pa_to_batch (line 102) | def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
    method dataloader (line 105) | def dataloader(self, remote: bool = False):
  function get_readers (line 119) | def get_readers(num_readers_per_worker: int):

FILE: reader/dds.py
  function maybe_start_dataset_service (line 23) | def maybe_start_dataset_service():
  function register_dataset (line 59) | def register_dataset(
  function distribute_from_dataset_id (line 78) | def distribute_from_dataset_id(
  function maybe_distribute_dataset (line 99) | def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:

FILE: reader/test_dataset.py
  function create_dataset (line 14) | def create_dataset(tmpdir):
  function test_dataset (line 36) | def test_dataset(tmpdir):
  function test_distributed_dataset (line 48) | def test_distributed_dataset(tmpdir):

FILE: reader/test_utils.py
  function test_rr (line 4) | def test_rr():

FILE: reader/utils.py
  function roundrobin (line 13) | def roundrobin(*iterables):
  function speed_check (line 37) | def speed_check(data_loader, max_steps: int, frequency: int, peek: Optio...
  function pa_to_torch (line 59) | def pa_to_torch(array: pa.array) -> torch.Tensor:
  function create_default_pa_to_batch (line 63) | def create_default_pa_to_batch(schema) -> DataclassBatch:

FILE: tools/pq.py
  function _create_dataset (line 40) | def _create_dataset(path: str):
  class PqReader (line 46) | class PqReader:
    method __init__ (line 47) | def __init__(
    method __iter__ (line 55) | def __iter__(self):
    method _head (line 64) | def _head(self):
    method bytes_per_row (line 73) | def bytes_per_row(self) -> int:
    method schema (line 83) | def schema(self):
    method head (line 86) | def head(self):
    method distinct (line 90) | def distinct(self):
Condensed preview of 111 files, each showing path, character count, and a content snippet. The full structured content is 407K chars.
[
  {
    "path": ".github/workflows/main.yml",
    "chars": 1159,
    "preview": "name: Python package\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-v"
  },
  {
    "path": ".gitignore",
    "chars": 352,
    "preview": "# Mac\n.DS_Store\n\n# Vim\n*.py.swp\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n\n# C extensions\n*.so\n\n# "
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 479,
    "preview": "repos:\n-   repo: https://github.com/pausan/cblack\n    rev: release-22.3.0\n    hooks:\n    - id: cblack\n      name: cblack"
  },
  {
    "path": "COPYING",
    "chars": 34523,
    "preview": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C)"
  },
  {
    "path": "LICENSE.torchrec",
    "chars": 1678,
    "preview": "A few files here (where it is specifically noted in comments) are based on code from torchrec but\nadapted for our use. T"
  },
  {
    "path": "README.md",
    "chars": 503,
    "preview": "This project open sources some of the ML models used at Twitter.\n\nCurrently these are:\n\n1. The \"For You\" Heavy Ranker (p"
  },
  {
    "path": "common/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "common/batch.py",
    "chars": 2508,
    "preview": "\"\"\"Extension of torchrec.dataset.utils.Batch to cover any dataset.\n\"\"\"\n# flake8: noqa\nfrom __future__ import annotations"
  },
  {
    "path": "common/checkpointing/__init__.py",
    "chars": 71,
    "preview": "from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot\n"
  },
  {
    "path": "common/checkpointing/snapshot.py",
    "chars": 9062,
    "preview": "import os\nimport time\nfrom typing import Any, Dict, List, Optional\n\nfrom tml.ml_logging.torch_logging import logging\nfro"
  },
  {
    "path": "common/device.py",
    "chars": 646,
    "preview": "import os\n\nimport torch\nimport torch.distributed as dist\n\n\ndef maybe_setup_tensorflow():\n  try:\n    import tensorflow as"
  },
  {
    "path": "common/filesystem/__init__.py",
    "chars": 72,
    "preview": "from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs\n"
  },
  {
    "path": "common/filesystem/test_infer_fs.py",
    "chars": 353,
    "preview": "\"\"\"Minimal test for infer_fs.\n\nMostly a test that it returns an object\n\"\"\"\nfrom tml.common.filesystem import infer_fs\n\n\n"
  },
  {
    "path": "common/filesystem/util.py",
    "chars": 544,
    "preview": "\"\"\"Utilities for interacting with the file systems.\"\"\"\nfrom fsspec.implementations.local import LocalFileSystem\nimport g"
  },
  {
    "path": "common/log_weights.py",
    "chars": 3355,
    "preview": "\"\"\"For logging model weights.\"\"\"\nimport itertools\nfrom typing import Callable, Dict, List, Optional, Union\n\nfrom tml.ml_"
  },
  {
    "path": "common/modules/embedding/config.py",
    "chars": 1886,
    "preview": "from typing import List\nfrom enum import Enum\n\nimport tml.core.config as base_config\nfrom tml.optimizers.config import O"
  },
  {
    "path": "common/modules/embedding/embedding.py",
    "chars": 1675,
    "preview": "from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType\nfrom tml.ml_logging.torch_logging import"
  },
  {
    "path": "common/run_training.py",
    "chars": 3562,
    "preview": "import os\nimport subprocess\nimport sys\nfrom typing import Optional\n\nfrom tml.ml_logging.torch_logging import logging  # "
  },
  {
    "path": "common/test_device.py",
    "chars": 342,
    "preview": "\"\"\"Minimal test for device.\n\nMostly a test that this can be imported properly even tho moved.\n\"\"\"\nfrom unittest.mock imp"
  },
  {
    "path": "common/testing_utils.py",
    "chars": 640,
    "preview": "from contextlib import contextmanager\nimport datetime\nimport os\nfrom unittest.mock import patch\n\nimport torch.distribute"
  },
  {
    "path": "common/utils.py",
    "chars": 1137,
    "preview": "import yaml\nimport getpass\nimport os\nimport string\nfrom typing import Tuple, Type, TypeVar\n\nfrom tml.core.config import "
  },
  {
    "path": "common/wandb.py",
    "chars": 908,
    "preview": "from typing import Any, Dict, List\n\nimport tml.core.config as base_config\n\nimport pydantic\n\n\nclass WandbConfig(base_conf"
  },
  {
    "path": "core/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "core/config/__init__.py",
    "chars": 246,
    "preview": "from tml.core.config.base_config import BaseConfig\nfrom tml.core.config.config_load import load_config_from_yaml\n\n# Make"
  },
  {
    "path": "core/config/base_config.py",
    "chars": 2228,
    "preview": "\"\"\"Base class for all config (forbids extra fields).\"\"\"\n\nimport collections\nimport functools\nimport yaml\n\nimport pydanti"
  },
  {
    "path": "core/config/base_config_test.py",
    "chars": 926,
    "preview": "from unittest import TestCase\n\nfrom tml.core.config import BaseConfig\n\nimport pydantic\n\n\nclass BaseConfigTest(TestCase):"
  },
  {
    "path": "core/config/config_load.py",
    "chars": 732,
    "preview": "import yaml\nimport string\nimport getpass\nimport os\nfrom typing import Type\n\nfrom tml.core.config.base_config import Base"
  },
  {
    "path": "core/config/test_config_load.py",
    "chars": 550,
    "preview": "from unittest import TestCase\n\nfrom tml.core.config import BaseConfig, load_config_from_yaml\n\nimport pydantic\nimport get"
  },
  {
    "path": "core/config/training.py",
    "chars": 1555,
    "preview": "from typing import Any, Dict, List, Optional\n\nfrom tml.common.wandb import WandbConfig\nfrom tml.core.config import base_"
  },
  {
    "path": "core/custom_training_loop.py",
    "chars": 11286,
    "preview": "\"\"\"Torch and torchrec specific training and evaluation loops.\n\nFeatures (go/100_enablements):\n    - CUDA data-fetch, com"
  },
  {
    "path": "core/debug_training_loop.py",
    "chars": 1162,
    "preview": "\"\"\"This is a very limited feature training loop useful for interactive debugging.\n\nIt is not intended for actual model t"
  },
  {
    "path": "core/loss_type.py",
    "chars": 146,
    "preview": "\"\"\"Loss type enums.\"\"\"\nfrom enum import Enum\n\n\nclass LossType(str, Enum):\n  CROSS_ENTROPY = \"cross_entropy\"\n  BCE_WITH_L"
  },
  {
    "path": "core/losses.py",
    "chars": 2982,
    "preview": "\"\"\"Loss functions -- including multi task ones.\"\"\"\n\nimport typing\n\nfrom tml.core.loss_type import LossType\nfrom tml.ml_l"
  },
  {
    "path": "core/metric_mixin.py",
    "chars": 3091,
    "preview": "\"\"\"\nMixin that requires a transform to munge output dictionary of tensors a\nmodel produces to a form that the torchmetri"
  },
  {
    "path": "core/metrics.py",
    "chars": 5052,
    "preview": "\"\"\"Common metrics that also support multi task.\n\nWe assume multi task models will output [task_idx, ...] predictions\n\n\"\""
  },
  {
    "path": "core/test_metrics.py",
    "chars": 4705,
    "preview": "from dataclasses import dataclass\n\nfrom tml.core import metrics as core_metrics\nfrom tml.core.metric_mixin import Metric"
  },
  {
    "path": "core/test_train_pipeline.py",
    "chars": 2618,
    "preview": "from dataclasses import dataclass\nfrom typing import Tuple\n\nfrom tml.common.batch import DataclassBatch\nfrom tml.common."
  },
  {
    "path": "core/train_pipeline.py",
    "chars": 21915,
    "preview": "\"\"\"\nTaken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py\nwith Tra"
  },
  {
    "path": "images/init_venv.sh",
    "chars": 677,
    "preview": "#! /bin/sh\n\nif [[ \"$(uname)\" == \"Darwin\" ]]; then\n  echo \"Only supported on Linux.\"\n  exit 1\nfi\n\n# You may need to point"
  },
  {
    "path": "images/requirements.txt",
    "chars": 2768,
    "preview": "absl-py==1.4.0\naiofiles==22.1.0\naiohttp==3.8.3\naiosignal==1.3.1\nappdirs==1.4.4\narrow==1.2.3\nasttokens==2.2.1\nastunparse="
  },
  {
    "path": "machines/environment.py",
    "chars": 2437,
    "preview": "import json\nimport os\nfrom typing import List\n\n\nKF_DDS_PORT: int = 5050\nSLURM_DDS_PORT: int = 5051\nFLIGHT_SERVER_PORT: i"
  },
  {
    "path": "machines/get_env.py",
    "chars": 1397,
    "preview": "import tml.machines.environment as env\n\nfrom absl import app, flags\n\n\nFLAGS = flags.FLAGS\nflags.DEFINE_string(\"property\""
  },
  {
    "path": "machines/is_venv.py",
    "chars": 500,
    "preview": "\"\"\"This is intended to be run as a module.\ne.g. python -m tml.machines.is_venv\n\nExits with 0 ii running in venv, otherwi"
  },
  {
    "path": "machines/list_ops.py",
    "chars": 1160,
    "preview": "\"\"\"\nSimple str.split() parsing of input string\n\nusage example:\n  python list_ops.py --input_list=$INPUT [--sep=\",\"] [--o"
  },
  {
    "path": "metrics/__init__.py",
    "chars": 119,
    "preview": "from .aggregation import StableMean  # noqa\nfrom .auroc import AUROCWithMWU  # noqa\nfrom .rce import NRCE, RCE  # noqa\n"
  },
  {
    "path": "metrics/aggregation.py",
    "chars": 3289,
    "preview": "\"\"\"\nContains aggregation metrics.\n\"\"\"\nfrom typing import Tuple, Union\n\nimport torch\nimport torchmetrics\n\n\ndef update_mea"
  },
  {
    "path": "metrics/auroc.py",
    "chars": 6045,
    "preview": "\"\"\"\nAUROC metrics.\n\"\"\"\nfrom typing import Union\n\nfrom tml.ml_logging.torch_logging import logging\n\nimport torch\nimport t"
  },
  {
    "path": "metrics/rce.py",
    "chars": 9927,
    "preview": "\"\"\"\nContains RCE metrics.\n\"\"\"\nimport copy\nfrom functools import partial\nfrom typing import Union\n\nfrom tml.metrics impor"
  },
  {
    "path": "ml_logging/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ml_logging/absl_logging.py",
    "chars": 763,
    "preview": "\"\"\"Sets up logging through absl for training usage.\n\n- Redirects logging to sys.stdout so that severity levels in GCP St"
  },
  {
    "path": "ml_logging/test_torch_logging.py",
    "chars": 514,
    "preview": "import unittest\n\nfrom tml.ml_logging.torch_logging import logging\n\n\nclass Testtlogging(unittest.TestCase):\n  def test_wa"
  },
  {
    "path": "ml_logging/torch_logging.py",
    "chars": 2039,
    "preview": "\"\"\"Overrides absl logger to be rank-aware for distributed pytorch usage.\n\n    >>> # in-bazel import\n    >>> from twitter"
  },
  {
    "path": "model.py",
    "chars": 2840,
    "preview": "\"\"\"Wraps servable model in loss and RecapBatch passing to be trainable.\"\"\"\n# flake8: noqa\nfrom typing import Callable\n\nf"
  },
  {
    "path": "optimizers/__init__.py",
    "chars": 48,
    "preview": "from tml.optimizers.optimizer import compute_lr\n"
  },
  {
    "path": "optimizers/config.py",
    "chars": 2476,
    "preview": "\"\"\"Optimization configurations for models.\"\"\"\n\nimport typing\n\nimport tml.core.config as base_config\n\nimport pydantic\n\n\nc"
  },
  {
    "path": "optimizers/optimizer.py",
    "chars": 3494,
    "preview": "from typing import Dict, Tuple\nimport math\nimport bisect\n\nfrom tml.optimizers.config import (\n  LearningRate,\n  Optimize"
  },
  {
    "path": "projects/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "projects/home/recap/FEATURES.md",
    "chars": 60124,
    "preview": "# Overview\nBelow is a description of the major feature groups which are input to the Twitter Heavy Ranking model.\n\nNote "
  },
  {
    "path": "projects/home/recap/README.md",
    "chars": 4227,
    "preview": "# Heavy Ranker\n\n## Overview\n\nThe heavy ranker is a machine learning model used to rank tweets for the \"For You\" timeline"
  },
  {
    "path": "projects/home/recap/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "projects/home/recap/config/home_recap_2022/segdense.json",
    "chars": 2835,
    "preview": "{\n  \"schema\": [\n    {\n      \"dtype\": \"int64_list\",\n      \"feature_name\": \"home_recap_2022_discrete__segdense_vals\",\n    "
  },
  {
    "path": "projects/home/recap/config/local_prod.yaml",
    "chars": 14319,
    "preview": "training:\n  num_train_steps: 10\n  num_eval_steps: 5\n  checkpoint_every_n: 5\n  train_log_every_n: 1\n  eval_log_every_n: 1"
  },
  {
    "path": "projects/home/recap/config.py",
    "chars": 2008,
    "preview": "from tml.core import config as config_mod\nimport tml.projects.home.recap.data.config as data_config\nimport tml.projects."
  },
  {
    "path": "projects/home/recap/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "projects/home/recap/data/config.py",
    "chars": 8313,
    "preview": "import typing\nfrom enum import Enum\n\n\nfrom tml.core import config as base_config\n\nimport pydantic\n\n\nclass ExplicitDateIn"
  },
  {
    "path": "projects/home/recap/data/dataset.py",
    "chars": 16411,
    "preview": "from dataclasses import dataclass\nfrom typing import Callable, List, Optional, Tuple, Dict\nimport functools\n\nimport torc"
  },
  {
    "path": "projects/home/recap/data/generate_random_data.py",
    "chars": 2617,
    "preview": "import os\nimport json\nfrom absl import app, flags, logging\nimport tensorflow as tf\nfrom typing import Dict\n\nfrom tml.pro"
  },
  {
    "path": "projects/home/recap/data/preprocessors.py",
    "chars": 8251,
    "preview": "\"\"\"\nPreprocessors applied on DDS workers in order to modify the dataset on the fly.\nSome of these preprocessors are also"
  },
  {
    "path": "projects/home/recap/data/tfe_parsing.py",
    "chars": 4409,
    "preview": "import functools\nimport json\n\nfrom tml.projects.home.recap.data import config as recap_data_config\n\nfrom absl import log"
  },
  {
    "path": "projects/home/recap/data/util.py",
    "chars": 3748,
    "preview": "from typing import Mapping, Tuple, Union\nimport torch\nimport torchrec\nimport numpy as np\nimport tensorflow as tf\n\n\ndef k"
  },
  {
    "path": "projects/home/recap/embedding/config.py",
    "chars": 3824,
    "preview": "from typing import List, Optional\nimport tml.core.config as base_config\nfrom tml.optimizers import config as optimizer_c"
  },
  {
    "path": "projects/home/recap/main.py",
    "chars": 3322,
    "preview": "import datetime\nimport os\nfrom typing import Callable, List, Optional, Tuple\nimport tensorflow as tf\n\nimport tml.common."
  },
  {
    "path": "projects/home/recap/model/__init__.py",
    "chars": 202,
    "preview": "from tml.projects.home.recap.model.entrypoint import (\n  create_ranking_model,\n  sanitize,\n  unsanitize,\n  MultiTaskRank"
  },
  {
    "path": "projects/home/recap/model/config.py",
    "chars": 9669,
    "preview": "\"\"\"Configuration for the main Recap model.\"\"\"\n\nimport enum\nfrom typing import List, Optional, Dict\n\nimport tml.core.conf"
  },
  {
    "path": "projects/home/recap/model/entrypoint.py",
    "chars": 11384,
    "preview": "from __future__ import annotations\n\nfrom absl import logging\nimport torch\nfrom typing import Optional, Callable, Mapping"
  },
  {
    "path": "projects/home/recap/model/feature_transform.py",
    "chars": 3821,
    "preview": "from typing import Mapping, Sequence, Union\n\nfrom tml.projects.home.recap.model.config import (\n  BatchNormConfig,\n  Dou"
  },
  {
    "path": "projects/home/recap/model/mask_net.py",
    "chars": 3769,
    "preview": "\"\"\"MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619).\"\"\"\n\nfrom tml.projects.home.recap.model import config, mlp\n\ni"
  },
  {
    "path": "projects/home/recap/model/mlp.py",
    "chars": 1742,
    "preview": "\"\"\"MLP feed forward stack in torch.\"\"\"\n\nfrom tml.projects.home.recap.model.config import MlpConfig\n\nimport torch\nfrom ab"
  },
  {
    "path": "projects/home/recap/model/model_and_loss.py",
    "chars": 2403,
    "preview": "from typing import Callable, Optional, List\nfrom tml.projects.home.recap.embedding import config as embedding_config_mod"
  },
  {
    "path": "projects/home/recap/model/numeric_calibration.py",
    "chars": 546,
    "preview": "import torch\n\n\nclass NumericCalibration(torch.nn.Module):\n  def __init__(\n    self,\n    pos_downsampling_rate: float,\n  "
  },
  {
    "path": "projects/home/recap/optimizer/__init__.py",
    "chars": 72,
    "preview": "from tml.projects.home.recap.optimizer.optimizer import build_optimizer\n"
  },
  {
    "path": "projects/home/recap/optimizer/config.py",
    "chars": 1180,
    "preview": "\"\"\"Optimization configurations for models.\"\"\"\n\nimport typing\n\nimport tml.core.config as base_config\nimport tml.optimizer"
  },
  {
    "path": "projects/home/recap/optimizer/optimizer.py",
    "chars": 6165,
    "preview": "\"\"\"Build optimizers and learning rate schedules.\"\"\"\nimport bisect\nfrom collections import defaultdict\nimport functools\ni"
  },
  {
    "path": "projects/home/recap/script/create_random_data.sh",
    "chars": 348,
    "preview": "#!/usr/bin/env bash\n\n# Runs from inside venv\n\nrm -rf $HOME/tmp/runs/recap_local_random_data\npython -m tml.machines.is_ve"
  },
  {
    "path": "projects/home/recap/script/run_local.sh",
    "chars": 391,
    "preview": "#!/usr/bin/env bash\n\n# Runs from inside venv\nrm -rf $HOME/tmp/runs/recap_local_debug\nmkdir -p $HOME/tmp/runs/recap_local"
  },
  {
    "path": "projects/twhin/README.md",
    "chars": 1205,
    "preview": "Twhin in torchrec\n\nThis project contains code for pretraining dense vector embedding features for Twitter entities. With"
  },
  {
    "path": "projects/twhin/config/local.yaml",
    "chars": 1469,
    "preview": "runtime:\n  enable_amp: false\ntraining:\n  save_dir: \"/tmp/model\"\n  num_train_steps: 100000\n  checkpoint_every_n: 100000\n "
  },
  {
    "path": "projects/twhin/config.py",
    "chars": 502,
    "preview": "from tml.core.config import base_config\nfrom tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.projects.twh"
  },
  {
    "path": "projects/twhin/data/config.py",
    "chars": 361,
    "preview": "from tml.core.config import base_config\n\nimport pydantic\n\n\nclass TwhinDataConfig(base_config.BaseConfig):\n  data_root: s"
  },
  {
    "path": "projects/twhin/data/data.py",
    "chars": 618,
    "preview": "from tml.projects.twhin.data.config import TwhinDataConfig\nfrom tml.projects.twhin.models.config import TwhinModelConfig"
  },
  {
    "path": "projects/twhin/data/edges.py",
    "chars": 5120,
    "preview": "from collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple\n\nfrom tml.com"
  },
  {
    "path": "projects/twhin/data/test_data.py",
    "chars": 81,
    "preview": "import pytest\nfrom unittest.mock import Mock\n\n\ndef test_create_dataset():\n  pass\n"
  },
  {
    "path": "projects/twhin/data/test_edges.py",
    "chars": 1724,
    "preview": "\"\"\"Tests edges dataset functionality.\"\"\"\n\nfrom unittest.mock import patch\nimport os\nimport tempfile\n\nfrom tml.projects.t"
  },
  {
    "path": "projects/twhin/machines.yaml",
    "chars": 183,
    "preview": "chief: &gpu\n  mem: 1.4Ti\n  cpu: 24\n  num_accelerators: 16\n  accelerator_type: a100\ndataset_dispatcher:\n  mem: 2Gi\n  cpu:"
  },
  {
    "path": "projects/twhin/metrics.py",
    "chars": 287,
    "preview": "import torch\nimport torchmetrics as tm\n\nimport tml.core.metrics as core_metrics\n\n\ndef create_metrics(\n  device: torch.de"
  },
  {
    "path": "projects/twhin/models/config.py",
    "chars": 1831,
    "preview": "import typing\nimport enum\n\nfrom tml.common.modules.embedding.config import LargeEmbeddingsConfig\nfrom tml.core.config im"
  },
  {
    "path": "projects/twhin/models/models.py",
    "chars": 5454,
    "preview": "from typing import Callable\nimport math\n\nfrom tml.projects.twhin.data.edges import EdgeBatch\nfrom tml.projects.twhin.mod"
  },
  {
    "path": "projects/twhin/models/test_models.py",
    "chars": 2897,
    "preview": "from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig\nfrom tml.projects.twhin.data.config"
  },
  {
    "path": "projects/twhin/optimizer.py",
    "chars": 2166,
    "preview": "import functools\n\nfrom tml.projects.twhin.models.config import TwhinModelConfig\nfrom tml.projects.twhin.models.models im"
  },
  {
    "path": "projects/twhin/run.py",
    "chars": 3222,
    "preview": "from absl import app, flags\nimport json\nfrom typing import Optional\nimport os\nimport sys\n\nimport torch\n\n# isort: on\nfrom"
  },
  {
    "path": "projects/twhin/scripts/docker_run.sh",
    "chars": 276,
    "preview": "#! /bin/sh\n\ndocker run -it --rm \\\n  -v $HOME/workspace/tml:/usr/src/app/tml \\\n  -v $HOME/.config:/root/.config \\\n  -w /u"
  },
  {
    "path": "projects/twhin/scripts/run_in_docker.sh",
    "chars": 226,
    "preview": "#! /bin/sh\n\ntorchrun \\\n  --standalone \\\n  --nnodes 1 \\\n  --nproc_per_node 2 \\\n  /usr/src/app/tml/projects/twhin/run.py \\"
  },
  {
    "path": "projects/twhin/test_optimizer.py",
    "chars": 1096,
    "preview": "import pytest\nimport unittest\n\nfrom tml.projects.twhin.models.models import TwhinModel, apply_optimizers\nfrom tml.projec"
  },
  {
    "path": "pyproject.toml",
    "chars": 185,
    "preview": "[tool.black]\nline-length = 100\ninclude = '\\.pyi?$'\nexclude = '''\n/(\n    \\.git\n  | \\.hg\n  | \\.pem\n  | \\.mypy_cache\n  | \\."
  },
  {
    "path": "reader/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "reader/dataset.py",
    "chars": 4451,
    "preview": "\"\"\"Dataset to be overwritten that can work with or without distributed reading.\n\n- Override `pa_to_batch` for dataset sp"
  },
  {
    "path": "reader/dds.py",
    "chars": 3762,
    "preview": "\"\"\"Dataset service orchestrated by a TFJob\n\"\"\"\nfrom typing import Optional\nimport uuid\n\nfrom tml.ml_logging.torch_loggin"
  },
  {
    "path": "reader/test_dataset.py",
    "chars": 1969,
    "preview": "import multiprocessing as mp\nimport os\nfrom unittest.mock import patch\n\nimport tml.reader.utils as reader_utils\nfrom tml"
  },
  {
    "path": "reader/test_utils.py",
    "chars": 211,
    "preview": "import tml.reader.utils as reader_utils\n\n\ndef test_rr():\n  options = [\"a\", \"b\", \"c\"]\n  rr = reader_utils.roundrobin(opti"
  },
  {
    "path": "reader/utils.py",
    "chars": 2537,
    "preview": "\"\"\"Reader utilities.\"\"\"\nimport itertools\nimport time\nfrom typing import Optional\n\nfrom tml.common.batch import Dataclass"
  },
  {
    "path": "tools/pq.py",
    "chars": 2620,
    "preview": "\"\"\"Local reader of parquet files.\n\n1. Make sure you are initialized locally:\n  ```\n  ./images/init_venv_macos.sh\n  ```\n2"
  }
]

About this extraction

This page contains the full source code of the twitter/the-algorithm-ml GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 111 files (376.7 KB), approximately 98.5k tokens, and a symbol index with 456 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.