Repository: yz93/LAVT-RIS
Branch: main
Commit: 1da0af9f21b6
Files: 59
Total size: 1.1 MB
Directory structure:
gitextract_240sk8kr/
├── LICENSE
├── README.md
├── args.py
├── bert/
│   ├── activations.py
│   ├── configuration_bert.py
│   ├── configuration_utils.py
│   ├── file_utils.py
│   ├── generation_utils.py
│   ├── modeling_bert.py
│   ├── modeling_utils.py
│   ├── tokenization_bert.py
│   ├── tokenization_utils.py
│   └── tokenization_utils_base.py
├── data/
│   └── dataset_refer_bert.py
├── demo_inference.py
├── lib/
│   ├── _utils.py
│   ├── backbone.py
│   ├── mask_predictor.py
│   ├── mmcv_custom/
│   │   ├── __init__.py
│   │   └── checkpoint.py
│   └── segmentation.py
├── refer/
│   ├── LICENSE
│   ├── Makefile
│   ├── README.md
│   ├── data/
│   │   └── README.md
│   ├── evaluation/
│   │   ├── __init__.py
│   │   ├── bleu/
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── bleu.py
│   │   │   └── bleu_scorer.py
│   │   ├── cider/
│   │   │   ├── __init__.py
│   │   │   ├── cider.py
│   │   │   └── cider_scorer.py
│   │   ├── meteor/
│   │   │   ├── __init__.py
│   │   │   └── meteor.py
│   │   ├── readme.txt
│   │   ├── refEvaluation.py
│   │   ├── rouge/
│   │   │   ├── __init__.py
│   │   │   └── rouge.py
│   │   └── tokenizer/
│   │       ├── __init__.py
│   │       ├── ptbtokenizer.py
│   │       └── stanford-corenlp-3.4.1.jar
│   ├── external/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _mask.pyx
│   │   ├── mask.py
│   │   ├── maskApi.c
│   │   └── maskApi.h
│   ├── pyEvalDemo.ipynb
│   ├── pyReferDemo.ipynb
│   ├── refer.py
│   ├── setup.py
│   └── test/
│       ├── sample_expressions_testA.json
│       └── sample_expressions_testB.json
├── requirements.txt
├── test.py
├── train.py
├── transforms.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
================================================
FILE: README.md
================================================
# LAVT: Language-Aware Vision Transformer for Referring Image Segmentation
Welcome to the official repository for the method presented in
"LAVT: Language-Aware Vision Transformer for Referring Image Segmentation."

Code in this repository is written using [PyTorch](https://pytorch.org/) and is organized in the following way (assuming the working directory is the root directory of this repository):
* `./lib` contains files implementing the main network.
* Inside `./lib`, `_utils.py` defines the highest-level model, which incorporates the backbone network
defined in `backbone.py` and the simple mask decoder defined in `mask_predictor.py`.
`segmentation.py` provides the model interface and initialization functions.
* `./bert` contains files migrated from [Hugging Face Transformers v3.0.2](https://huggingface.co/transformers/v3.0.2/quicktour.html),
which implement the BERT language model.
We used Transformers v3.0.2 during development, but it had a bug that surfaced when using `DistributedDataParallel`.
Therefore, we maintain a copy of the relevant source files in this repository.
This way, the bug is fixed and the code in this repository is self-contained.
* `./train.py` is invoked to train the model.
* `./test.py` is invoked to run inference on the evaluation subsets after training.
* `./refer` contains data pre-processing code and is also where data should be placed, including the images and all annotations.
It is cloned from [refer](https://github.com/lichengunc/refer).
* `./data/dataset_refer_bert.py` is where the dataset class is defined.
* `./utils.py` defines functions that track training statistics and setup
functions for `DistributedDataParallel`.
## Updates
**April 13<sup>th</sup>, 2023**. Using the Dice loss instead of the cross-entropy loss can improve results. We will add the code and release the weights later when we get a chance.

**June 21<sup>st</sup>, 2022**. Uploaded the training logs and trained
model weights of lavt_one.

**June 9<sup>th</sup>, 2022**.
Added a more efficient implementation of LAVT.
* To train this new model, specify `--model` as `lavt_one`
(and `lavt` is still valid for specifying the old model).
The rest of the configuration stays unchanged.
* The difference between this version and the previous one
is that the language model has been moved inside the overall model,
so that `DistributedDataParallel` needs to be applied only once.
Applying it twice (on the standalone language model and the main branch),
as done in the old implementation, led to low GPU utilization,
which slowed down training.
We recommend training this model on 8 GPUs
(with the same total batch size of 32 as before).
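
The toy sketch below shows why moving the language model inside the overall model means `DistributedDataParallel` only has to be applied once. The text encoder and the vision branch are both submodules of a single `nn.Module`, so one wrapper covers every parameter. All module names here are hypothetical stand-ins, not the real BERT, Swin, or PWAM components. The DDP call is left as a comment because it needs an initialized process group.

```python
import torch
import torch.nn as nn

class ToyTextEncoder(nn.Module):
    """Hypothetical stand-in for BERT: embeds token ids and zeroes padding."""
    def __init__(self, vocab=100, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, ids, mask):
        return self.embed(ids) * mask.unsqueeze(-1)  # (B, L, D)

class ToyVisionBranch(nn.Module):
    """Hypothetical stand-in for the language-aware visual encoder and decoder."""
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, 1)
        self.head = nn.Conv2d(dim, 2, 1)  # 2 output channels: background / target

    def forward(self, img, lang, mask):
        # Average the non-padded word features into one sentence vector.
        pooled = lang.sum(1) / mask.sum(1, keepdim=True).clamp(min=1)  # (B, D)
        feat = self.proj(img) * pooled[:, :, None, None]  # condition pixels on text
        return self.head(feat)  # (B, 2, H, W)

class OneModel(nn.Module):
    """Both branches live in one module, so a single DDP wrapper sees all params."""
    def __init__(self):
        super().__init__()
        self.text_encoder = ToyTextEncoder()
        self.vision = ToyVisionBranch()

    def forward(self, img, ids, mask):
        lang = self.text_encoder(ids, mask)
        return self.vision(img, lang, mask)

model = OneModel()
# Under torch.distributed (not executed here), wrap once:
# model = nn.parallel.DistributedDataParallel(model.cuda(local_rank), device_ids=[local_rank])
```

In the old `lavt` design, by contrast, the standalone language model and the main branch were wrapped separately.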
## Setting Up
### Preliminaries
The code has been verified to work with PyTorch v1.7.1 and Python 3.7.
1. Clone this repository.
2. Change directory to root of this repository.
### Package Dependencies
1. Create a new Conda environment with Python 3.7, then activate it:
```shell
conda create -n lavt python==3.7
conda activate lavt
```
2. Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example):
```shell
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch
```
3. Install the packages in `requirements.txt` via `pip`:
```shell
pip install -r requirements.txt
```
### Datasets
1. Follow instructions in the `./refer` directory to set up subdirectories
and download annotations.
This directory is a git clone (minus two data files that we do not need)
from the [refer](https://github.com/lichengunc/refer) public API.
2. Download images from [COCO](https://cocodataset.org/#download).
Please use the first download link, *2014 Train images [83K/13GB]*, and extract
the downloaded `train2014.zip` file to `./refer/data/images/mscoco/images`.
### The Initialization Weights for Training
1. Create the `./pretrained_weights` directory where we will be storing the weights.
```shell
mkdir ./pretrained_weights
```
2. Download [pre-trained classification weights of
the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth),
and put the `pth` file in `./pretrained_weights`.
These weights are needed for training to initialize the model.
### Trained Weights of LAVT for Testing
1. Create the `./checkpoints` directory where we will be storing the weights.
```shell
mkdir ./checkpoints
```
2. Download LAVT model weights (which are stored on Google Drive) using links below and put them in `./checkpoints`.
| [RefCOCO](https://drive.google.com/file/d/13D-OeEOijV8KTC3BkFP-gOJymc6DLwVT/view?usp=sharing) | [RefCOCO+](https://drive.google.com/file/d/1B8Q44ZWsc8Pva2xD_M-KFh7-LgzeH2-2/view?usp=sharing) | [G-Ref (UMD)](https://drive.google.com/file/d/1BjUnPVpALurkGl7RXXvQiAHhA-gQYKvK/view?usp=sharing) | [G-Ref (Google)](https://drive.google.com/file/d/1weiw5UjbPfo3tCBPfB8tu6xFXCUG16yS/view?usp=sharing) |
|---|---|---|---|
3. Model weights and training logs of the new lavt_one implementation are below.
| RefCOCO | RefCOCO+ | G-Ref (UMD) | G-Ref (Google) |
|:-----:|:-----:|:-----:|:-----:|
|[log](https://drive.google.com/file/d/1YIojIHqe3bxxsWOltifa2U9jH67hPHLM/view?usp=sharing) | [weights](https://drive.google.com/file/d/1xFMEXr6AGU97Ypj1yr8oo00uObbeIQvJ/view?usp=sharing)|[log](https://drive.google.com/file/d/1Z34T4gEnWlvcSUQya7txOuM0zdLK7MRT/view?usp=sharing) | [weights](https://drive.google.com/file/d/1HS8ZnGaiPJr-OmoUn4-4LVnVtD_zHY6w/view?usp=sharing)|[log](https://drive.google.com/file/d/14VAgahngOV8NA6noLZCqDoqaUrlW14v8/view?usp=sharing) | [weights](https://drive.google.com/file/d/14g8NzgZn6HzC6tP_bsQuWmh5LnOcovsE/view?usp=sharing)|[log](https://drive.google.com/file/d/1JBXfmlwemWSvs92Rky0TlHcVuuLpt4Da/view?usp=sharing) | [weights](https://drive.google.com/file/d/1IJeahFVLgKxu_BVmWacZs3oUzgTCeWcz/view?usp=sharing)|
* The Prec@K, overall IoU and mean IoU numbers in the training logs will differ
from the final results obtained by running `test.py`,
because during training only one of the multiple annotated expressions for each object is
randomly selected and evaluated.
Nevertheless, these numbers give a good indication of test performance,
and the two should be fairly close.
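
For reference, the three metrics can be computed from binary masks as shown below. This sketch assumes the usual definitions in referring segmentation:

* Overall IoU is the cumulative intersection divided by the cumulative union across the evaluated set.
* Mean IoU is the average of the per-sample IoUs.
* Prec@K is the fraction of samples whose IoU is at least K.

The exact comparison and edge-case handling in `test.py` may differ. `refer_seg_metrics` is a hypothetical name.

```python
import numpy as np

def refer_seg_metrics(preds, gts, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """Return (Prec@K dict, overall IoU, mean IoU) for lists of binary masks."""
    cum_inter, cum_union = 0, 0
    ious = []
    for pred, gt in zip(preds, gts):
        pred = pred.astype(bool)
        gt = gt.astype(bool)
        inter = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        cum_inter += inter
        cum_union += union
        # Convention chosen here: an empty union counts as IoU 0.
        ious.append(inter / union if union > 0 else 0.0)
    ious = np.array(ious, dtype=float)
    prec = {k: float((ious >= k).mean()) for k in thresholds}
    overall_iou = float(cum_inter / cum_union) if cum_union > 0 else 0.0
    mean_iou = float(ious.mean())
    return prec, overall_iou, mean_iou
```

Overall IoU weights large objects more heavily, because their pixels dominate the sums. Mean IoU treats every expression equally. This is one reason the two numbers in the tables below differ.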
## Training
We use `DistributedDataParallel` from PyTorch.
The released `lavt` weights were trained using 4 x 32G V100 cards (peak memory usage per card was about 26G).
The released `lavt_one` weights were trained using 8 x 32G V100 cards (peak memory usage per card was about 13G).
The additional cards were used to accelerate training.
To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node:
```shell
mkdir ./models
mkdir ./models/refcoco
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco --model_id refcoco --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco/output
mkdir ./models/refcoco+
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco+ --model_id refcoco+ --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco+/output
mkdir ./models/gref_umd
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy umd --model_id gref_umd --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_umd/output
mkdir ./models/gref_google
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy google --model_id gref_google --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_google/output
```
* *--model* is a pre-defined model name. Options include `lavt` and `lavt_one`. See [Updates](#updates).
* *--dataset* is the dataset name. One can choose from `refcoco`, `refcoco+`, and `refcocog`.
* *--splitBy* needs to be specified if and only if the dataset is G-Ref (which is also called RefCOCOg).
`umd` identifies the UMD partition and `google` identifies the Google partition.
* *--model_id* is the model name one should define oneself (*e.g.*, customize it to contain training/model configurations, dataset information, experiment IDs, *etc*.).
It is used in two ways: Training log will be saved as `./models/[args.model_id]/output` and the best checkpoint will be saved as `./checkpoints/model_best_[args.model_id].pth`.
* *--swin_type* specifies the version of the Swin Transformer.
One can choose from `tiny`, `small`, `base`, and `large`. The default is `base`.
* *--pretrained_swin_weights* specifies the path to pre-trained Swin Transformer weights used for model initialization.
* Note that currently we need to manually create the `./models/[args.model_id]` directory via `mkdir` before running `train.py`.
This is because we use `tee` to redirect `stdout` and `stderr` to `./models/[args.model_id]/output` for logging.
This is a nuisance and should be resolved in the future, *e.g.*, by using a proper logger or a bash script to launch training.
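
The sketch below shows the "proper logger" idea using Python's standard `logging` module. It creates `./models/[args.model_id]` itself and writes messages both to the console and to the `output` file there. `setup_run_logger` is a hypothetical helper and is not part of this repository; how it would be wired into `train.py` is an open choice.

```python
import logging
import os
import sys

def setup_run_logger(model_id, root="./models"):
    """Hypothetical helper: create <root>/<model_id> and log to its `output` file and stdout."""
    run_dir = os.path.join(root, model_id)
    os.makedirs(run_dir, exist_ok=True)  # replaces the manual `mkdir`
    logger = logging.getLogger("lavt." + model_id)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # do not duplicate messages via the root logger
    if not logger.handlers:  # avoid adding handlers twice for the same logger name
        fmt = logging.Formatter("%(asctime)s %(message)s")
        for handler in (logging.FileHandler(os.path.join(run_dir, "output")),
                        logging.StreamHandler(sys.stdout)):
            handler.setFormatter(fmt)
            logger.addHandler(handler)
    return logger
```

With `DistributedDataParallel`, one would typically call this only on rank 0, so that several processes do not write to the same file.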
## Testing
For RefCOCO/RefCOCO+, run one of the following:
```shell
python test.py --model lavt --swin_type base --dataset refcoco --split val --resume ./checkpoints/refcoco.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
python test.py --model lavt --swin_type base --dataset refcoco+ --split val --resume ./checkpoints/refcoco+.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
```
* *--split* is the subset to evaluate, and one can choose from `val`, `testA`, and `testB`.
* *--resume* is the path to the weights of a trained model.
For G-Ref (UMD)/G-Ref (Google), run one of the following:
```shell
python test.py --model lavt --swin_type base --dataset refcocog --splitBy umd --split val --resume ./checkpoints/gref_umd.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
python test.py --model lavt --swin_type base --dataset refcocog --splitBy google --split val --resume ./checkpoints/gref_google.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
```
* *--splitBy* specifies the partition to evaluate.
One can choose from `umd` or `google`.
* *--split* is the subset (according to the specified partition) to evaluate; one can choose from `val` and `test` for the UMD partition, and only `val` for the Google partition.
* *--resume* is the path to the weights of a trained model.
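
The valid `--dataset`/`--splitBy`/`--split` combinations listed above can be encoded in a small helper, so that an invalid combination fails before `test.py` is launched. This is a hypothetical convenience sketch: `build_test_cmd` is not part of the repository. It reuses only the flags shown in the commands above and returns an argument list suitable for `subprocess.run`.

```python
# Valid evaluation subsets per (dataset, splitBy), as listed in this README.
VALID_SPLITS = {
    ("refcoco", None): {"val", "testA", "testB"},
    ("refcoco+", None): {"val", "testA", "testB"},
    ("refcocog", "umd"): {"val", "test"},
    ("refcocog", "google"): {"val"},
}

def build_test_cmd(dataset, split, resume, split_by=None, model="lavt"):
    """Hypothetical helper that assembles a test.py command line as a list."""
    key = (dataset, split_by)
    if key not in VALID_SPLITS:
        raise ValueError(f"unsupported dataset/splitBy combination: {key}")
    if split not in VALID_SPLITS[key]:
        raise ValueError(f"split {split!r} is not available for {key}")
    cmd = ["python", "test.py", "--model", model, "--swin_type", "base",
           "--dataset", dataset]
    if split_by is not None:  # --splitBy is only passed for G-Ref
        cmd += ["--splitBy", split_by]
    cmd += ["--split", split, "--resume", resume, "--workers", "4",
            "--ddp_trained_weights", "--window12", "--img_size", "480"]
    return cmd
```

For example, `subprocess.run(build_test_cmd("refcocog", "test", "./checkpoints/gref_umd.pth", split_by="umd"), check=True)` would evaluate the UMD test subset.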
## Results
1. The evaluation results (those reported in the paper) of LAVT trained with a cross-entropy loss and based on our original implementation are summarized as follows:
| Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU |
|:---------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----------:|:--------:|
| RefCOCO val | 84.46 | 80.90 | 75.28 | 64.71 | 34.30 | 72.73 | 74.46 |
| RefCOCO test A | 88.07 | 85.17 | 79.90 | 68.52 | 35.69 | 75.82 | 76.89 |
| RefCOCO test B | 79.12 | 74.94 | 69.17 | 59.37 | 34.45 | 68.79 | 70.94 |
| RefCOCO+ val | 74.44 | 70.91 | 65.58 | 56.34 | 30.23 | 62.14 | 65.81 |
| RefCOCO+ test A | 80.68 | 77.96 | 72.90 | 62.21 | 32.36 | 68.38 | 70.97 |
| RefCOCO+ test B | 65.66 | 61.85 | 55.94 | 47.56 | 27.24 | 55.10 | 59.23 |
| G-Ref val (UMD) | 70.81 | 65.28 | 58.60 | 47.49 | 22.73 | 61.24 | 63.34 |
| G-Ref test (UMD)| 71.54 | 66.38 | 59.00 | 48.21 | 23.10 | 62.09 | 63.62 |
|G-Ref val (Goog.)| 71.16 | 67.21 | 61.76 | 51.98 | 27.30 | 60.50 | 63.66 |
- We have validated LAVT on RefCOCO with multiple runs. The overall IoU on the val set generally lies in the range of 72.73±0.5%.
2. In the following, we report the results of LAVT trained with a multi-class Dice loss and based on the new implementation (`lavt_one`).
| Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU |
|:---------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----------:|:--------:|
| RefCOCO val | 85.87 | 82.13 | 76.64 | 65.45 | 35.30 | 73.50 | 75.41 |
| RefCOCO test A | 88.47 | 85.63 | 80.57 | 68.84 | 35.71 | 75.97 | 77.31 |
| RefCOCO test B | 80.20 | 76.49 | 70.34 | 60.12 | 34.94 | 69.33 | 71.86 |
| RefCOCO+ val | 76.19 | 72.27 | 66.82 | 56.87 | 30.15 | 63.79 | 67.65 |
| RefCOCO+ test A | 82.50 | 79.44 | 74.00 | 63.27 | 31.99 | 69.79 | 72.53 |
| RefCOCO+ test B | 68.03 | 63.35 | 57.29 | 47.92 | 26.98 | 56.49 | 61.22 |
| G-Ref val (UMD) | 75.82 | 71.06 | 63.99 | 52.98 | 27.31 | 64.02 | 67.41 |
| G-Ref test (UMD)| 76.12 | 71.13 | 64.58 | 53.62 | 28.03 | 64.49 | 67.45 |
|G-Ref val (Goog.)| 72.57 | 68.65 | 63.09 | 53.33 | 28.14 | 61.31 | 64.84 |
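
The Dice-loss training code has not been released in this repository. The following is a minimal sketch of a generic multi-class (softmax) Dice loss in PyTorch, not the authors' implementation. It assumes the model outputs two-channel logits of shape `(N, 2, H, W)` (background/target) and integer masks of shape `(N, H, W)`. The smoothing constant `eps` is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

def multiclass_dice_loss(logits, target, eps=1.0):
    """Soft Dice loss averaged over classes.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer labels in [0, C).
    """
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dimensions, separately per class
    inter = (probs * one_hot).sum(dims)
    card = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * inter + eps) / (card + eps)  # per-class soft Dice coefficient
    return 1.0 - dice.mean()
```

Unlike per-pixel cross-entropy, this loss is normalized by region size. This is a common motivation for using Dice when the referred object occupies a small part of the image.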
## Demo: Try LAVT on Your Own Image-Text Pairs
You can run inference on any image-text pair
and visualize the result by running the script `./demo_inference.py`.
Have fun!
## Citing LAVT
```
@inproceedings{yang2022lavt,
title={LAVT: Language-Aware Vision Transformer for Referring Image Segmentation},
author={Yang, Zhao and Wang, Jiaqi and Tang, Yansong and Chen, Kai and Zhao, Hengshuang and Torr, Philip HS},
booktitle={CVPR},
year={2022}
}
```
## Contributing
We appreciate all contributions.
It helps the project if you could
- report issues you are facing,
- give a :+1: on issues reported by others that are relevant to you,
- answer issues reported by others for which you have found solutions,
- and implement helpful new features or improve the code otherwise with pull requests.
## Acknowledgements
Code in this repository is built upon several public repositories.
Specifically,
* data pre-processing leverages the [refer](https://github.com/lichengunc/refer) repository,
* the backbone model is implemented based on code from [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation),
* the training and testing pipelines are adapted from [RefVOS](https://github.com/miriambellver/refvos),
* and implementation of the BERT model (files in the bert directory) is from [Hugging Face Transformers v3.0.2](https://github.com/huggingface/transformers/tree/v3.0.2)
(we migrated over the relevant code to fix a bug and simplify the installation process).
Some of these repositories in turn adapt code from [OpenMMLab](https://github.com/open-mmlab) and [TorchVision](https://github.com/pytorch/vision).
We'd like to thank the authors/organizations of these repositories for open sourcing their projects.
## License
GNU GPLv3
================================================
FILE: args.py
================================================
import argparse
def get_parser():
    parser = argparse.ArgumentParser(description='LAVT training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs to be specified when testing: '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, '
                                                  'where a, b, c, and d refer to the numbers of heads in stage-1, '
                                                  'stage-2, stage-3, and stage-4 PWAMs')
    parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
    parser.add_argument('--model_id', default='lavt', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='test', help='only used when testing')
    parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--swin_type', default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('--window12', action='store_true',
                        help='only needs to be specified when testing; '
                             'when training, window size is inferred from pre-trained weights file name '
                             '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    return parser
if __name__ == "__main__":
    parser = get_parser()
    args_dict = parser.parse_args()
================================================
FILE: bert/activations.py
================================================
import logging
import math
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
def swish(x):
    return x * torch.sigmoid(x)
def _gelu_python(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        This is now written in C in torch.nn.functional
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
# Compare (major, minor) numerically; a plain string comparison would treat "1.10.0" as older than "1.4.0".
_TORCH_MAJOR_MINOR = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
if _TORCH_MAJOR_MINOR < (1, 4):
    gelu = _gelu_python
else:
    gelu = F.gelu
def gelu_fast(x):
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
ACT2FN = {
    "relu": F.relu,
    "swish": swish,
    "gelu": gelu,
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
}
def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
================================================
FILE: bert/configuration_bert.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
"cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
"cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
"TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
"wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
# See all BERT models at https://huggingface.co/models?filter=bert
}
class BertConfig(PretrainedConfig):
    r"""
        This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
        It is used to instantiate a BERT model according to the specified arguments, defining the model
        architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
        the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
        Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
        to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
        for more information.
        Args:
            vocab_size (:obj:`int`, optional, defaults to 30522):
                Vocabulary size of the BERT model. Defines the different tokens that
                can be represented by the `input_ids` passed to the forward method of :class:`~transformers.BertModel`.
            hidden_size (:obj:`int`, optional, defaults to 768):
                Dimensionality of the encoder layers and the pooler layer.
            num_hidden_layers (:obj:`int`, optional, defaults to 12):
                Number of hidden layers in the Transformer encoder.
            num_attention_heads (:obj:`int`, optional, defaults to 12):
                Number of attention heads for each attention layer in the Transformer encoder.
            intermediate_size (:obj:`int`, optional, defaults to 3072):
                Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
            hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
                The non-linear activation function (function or string) in the encoder and pooler.
                If string, "gelu", "relu", "swish" and "gelu_new" are supported.
            hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
                The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
                The dropout ratio for the attention probabilities.
            max_position_embeddings (:obj:`int`, optional, defaults to 512):
                The maximum sequence length that this model might ever be used with.
                Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size (:obj:`int`, optional, defaults to 2):
                The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
            initializer_range (:obj:`float`, optional, defaults to 0.02):
                The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
            layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
                The epsilon used by the layer normalization layers.
            gradient_checkpointing (:obj:`bool`, optional, defaults to False):
                If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        Example::
            >>> from transformers import BertModel, BertConfig
            >>> # Initializing a BERT bert-base-uncased style configuration
            >>> configuration = BertConfig()
            >>> # Initializing a model from the bert-base-uncased style configuration
            >>> model = BertModel(configuration)
            >>> # Accessing the model configuration
            >>> configuration = model.config
    """
    model_type = "bert"
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
================================================
FILE: bert/configuration_utils.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""
import copy
import json
import logging
import os
from typing import Dict, Tuple
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
logger = logging.getLogger(__name__)
class PretrainedConfig(object):
r""" Base class for all configuration classes.
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
Note:
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
It only affects the model's configuration.
Class attributes (overridden by derived classes):
- ``model_type``: a string that identifies the model type; it is serialized into the JSON file and used to recreate the correct object in :class:`~transformers.AutoConfig`.
Args:
finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
num_labels (:obj:`int`, `optional`, defaults to `2`):
Number of classes to use when the model is a classification model (sequences/tokens)
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model should return all hidden states.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model should return all attention weights.
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used with TorchScript (PyTorch models only).
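``num_labels`` is not stored as a plain attribute: the constructor either converts a given ``id2label`` (whose keys arrive as strings from JSON) or builds placeholder labels through the ``num_labels`` setter. The following is a minimal standalone sketch of that logic; the class name ``LabelConfigSketch`` is hypothetical and it models only the label-related attributes.

```python
class LabelConfigSketch:
    # Hypothetical stand-in that mirrors only the label handling of PretrainedConfig.
    def __init__(self, id2label=None, label2id=None, num_labels=2):
        if id2label is not None:
            # JSON object keys are always strings, so convert the ids back to int.
            self.id2label = {int(key): value for key, value in id2label.items()}
            self.label2id = label2id
        else:
            # Goes through the property setter below.
            self.num_labels = num_labels

    @property
    def num_labels(self):
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels):
        # Placeholder names LABEL_0 ... LABEL_{n-1} plus the reverse mapping.
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

cfg = LabelConfigSketch(num_labels=3)
print(cfg.id2label)  # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
loaded = LabelConfigSketch(id2label={"0": "neg", "1": "pos"})
print(loaded.num_labels, loaded.id2label[1])  # 2 pos
```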
"""
model_type: str = ""
def __init__(self, **kwargs):
# Attributes with defaults
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
# `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
self.min_length = kwargs.pop("min_length", 0)
self.do_sample = kwargs.pop("do_sample", False)
self.early_stopping = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
kwargs.pop("num_labels", None)
# Keys are always strings in JSON, so convert the ids to int here.
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
else:
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
# TPU arguments
self.xla_device = kwargs.pop("xla_device", None)
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
@property
def num_labels(self):
return len(self.id2label)
@num_labels.setter
def num_labels(self, num_labels):
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
def save_pretrained(self, save_directory):
"""
Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
Args:
save_directory (:obj:`string`):
Directory where the configuration JSON file will be saved.
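A save/load round trip only needs a directory containing ``config.json``. The sketch below is a simplified, standalone illustration: the helper names are hypothetical, and unlike ``save_pretrained`` (which writes only the difference from the defaults via ``use_diff=True``) it writes the whole dict.

```python
import json
import os
import tempfile

CONFIG_NAME = "config.json"  # same file name this module uses

def save_config_dict(config_dict, save_directory):
    # Mirror save_pretrained: refuse a file path, create the directory, write config.json.
    if os.path.isfile(save_directory):
        raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
    os.makedirs(save_directory, exist_ok=True)
    path = os.path.join(save_directory, CONFIG_NAME)
    with open(path, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(config_dict, indent=2, sort_keys=True))
    return path

def load_config_dict(directory):
    # Mirror the directory branch of get_config_dict plus _dict_from_json_file.
    with open(os.path.join(directory, CONFIG_NAME), "r", encoding="utf-8") as reader:
        return json.loads(reader.read())

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "my_model")
    save_config_dict({"hidden_size": 768, "model_type": "bert"}, target)
    print(load_config_dict(target))  # {'hidden_size': 768, 'model_type': 'bert'}
```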
"""
if os.path.isfile(save_directory):
raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
self.to_json_file(output_config_file, use_diff=True)
logger.info("Configuration saved in {}".format(output_config_file))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
r"""
Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
Args:
pretrained_model_name_or_path (:obj:`string`):
either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or
download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.:
``./my_model_directory/configuration.json``.
cache_dir (:obj:`string`, `optional`):
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
kwargs (:obj:`Dict[str, any]`, `optional`):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
controlled by the `return_unused_kwargs` keyword parameter.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
proxies (:obj:`Dict`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g.:
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
The proxies are used on each request.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If False, this function returns just the final configuration object.
If True, this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
dictionary consisting of the key/value pairs whose keys are not configuration attributes, i.e. the part
of kwargs that has not been used to update `config` and is otherwise ignored.
Returns:
:class:`PretrainedConfig`: An instance of a configuration object
Examples::
# We can't directly instantiate the base class `PretrainedConfig`, so the examples below use a
# derived class: BertConfig
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
assert config.output_attentions == True
config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
foo=False, return_unused_kwargs=True)
assert config.output_attentions == True
assert unused_kwargs == {'foo': False}
"""
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(config_dict, **kwargs)
@classmethod
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
for instantiating a Config using `from_dict`.
Parameters:
pretrained_model_name_or_path (:obj:`string`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
Returns:
:obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object, and the remaining keyword arguments that were not consumed while resolving the file.
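The branch order matters: an existing directory wins, then an existing file or an http(s) URL is used as-is, and only otherwise is the argument treated as a hosted model identifier. A standalone sketch of that classification (the helper ``classify_config_source`` is hypothetical and performs no download):

```python
import os
import tempfile
from urllib.parse import urlparse

CONFIG_NAME = "config.json"

def classify_config_source(name_or_path):
    # Mirrors the branch order in get_config_dict.
    if os.path.isdir(name_or_path):
        return "directory", os.path.join(name_or_path, CONFIG_NAME)
    if os.path.isfile(name_or_path) or urlparse(name_or_path).scheme in ("http", "https"):
        return "file_or_url", name_or_path
    # Anything else is treated as a remotely hosted model identifier.
    return "identifier", name_or_path

with tempfile.TemporaryDirectory() as tmp:
    print(classify_config_source(tmp)[0])  # directory
print(classify_config_source("https://example.com/config.json")[0])  # file_or_url
print(classify_config_source("bert-base-uncased")[0])  # identifier (if no such local path)
```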
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
msg = (
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
)
raise EnvironmentError(msg)
except json.JSONDecodeError:
msg = (
"Couldn't reach server at '{}' to download configuration file or "
"configuration file is not a valid JSON file. "
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
)
raise EnvironmentError(msg)
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
return config_dict, kwargs
@classmethod
def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
"""
Constructs a `Config` from a Python dictionary of parameters.
Args:
config_dict (:obj:`Dict[str, any]`):
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
method.
kwargs (:obj:`Dict[str, any]`):
Additional parameters from which to initialize the configuration object.
Returns:
:class:`PretrainedConfig`: An instance of a configuration object
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config %s", str(config))
if return_unused_kwargs:
return config, kwargs
else:
return config
@classmethod
def from_json_file(cls, json_file: str) -> "PretrainedConfig":
"""
Constructs a `Config` from the path to a JSON file of parameters.
Args:
json_file (:obj:`string`):
Path to the JSON file containing the parameters.
Returns:
:class:`PretrainedConfig`: An instance of a configuration object
"""
config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)
@classmethod
def _dict_from_json_file(cls, json_file: str):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return "{} {}".format(self.__class__.__name__, self.to_json_string())
def to_diff_dict(self):
"""
Removes all attributes from config which correspond to the default
config attributes for better readability and serializes to a Python
dictionary.
Returns:
:obj:`Dict[str, any]`: Dictionary of the attributes of this configuration instance whose values differ from those of a default :class:`PretrainedConfig`.
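The diff keeps a key when the defaults do not know it or when its value differs. A standalone sketch of that filter; the ``defaults`` dict here is a tiny hypothetical subset, whereas the real method compares against ``PretrainedConfig().to_dict()``:

```python
def diff_against_defaults(config_dict, default_dict):
    # Keep keys that are unknown to the defaults or whose value differs.
    return {k: v for k, v in config_dict.items() if k not in default_dict or v != default_dict[k]}

defaults = {"output_attentions": False, "num_beams": 1, "model_type": ""}
current = {"output_attentions": False, "num_beams": 4, "model_type": "bert", "hidden_size": 768}
print(diff_against_defaults(current, defaults))
# {'num_beams': 4, 'model_type': 'bert', 'hidden_size': 768}
```

Because the base class has an empty ``model_type``, a derived config such as ``BertConfig`` always keeps ``model_type`` in its serialized diff.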
"""
config_dict = self.to_dict()
# get the default config dict
default_config_dict = PretrainedConfig().to_dict()
serializable_config_dict = {}
# only serialize values that differ from the default config
for key, value in config_dict.items():
if key not in default_config_dict or value != default_config_dict[key]:
serializable_config_dict[key] = value
return serializable_config_dict
def to_dict(self):
"""
Serializes this instance to a Python dictionary.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
return output
def to_json_string(self, use_diff=True):
"""
Serializes this instance to a JSON string.
Args:
use_diff (:obj:`bool`):
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to a JSON string.
Returns:
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
"""
if use_diff is True:
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path, use_diff=True):
"""
Save this instance to a JSON file.
Args:
json_file_path (:obj:`string`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (:obj:`bool`):
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to the JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
def update(self, config_dict: Dict):
"""
Updates attributes of this instance with the values from `config_dict`.
Args:
config_dict (:obj:`Dict[str, any]`):
Dictionary of attributes that shall be updated for this instance.
"""
for key, value in config_dict.items():
setattr(self, key, value)
================================================
FILE: bert/file_utils.py
================================================
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import fnmatch
import json
import logging
import os
import shutil
import sys
import tarfile
import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
import requests
from filelock import FileLock
from tqdm.auto import tqdm
#from . import __version__
__version__ = "3.0.2"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
import torch
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm # noqa: F401
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
try:
import psutil # noqa: F401
_psutil_available = True
except ImportError:
_psutil_available = False
try:
import py3nvml # noqa: F401
_py3nvml_available = True
except ImportError:
_py3nvml_available = False
try:
from apex import amp # noqa: F401
_has_apex = True
except ImportError:
_has_apex = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available
def is_torch_tpu_available():
return _torch_tpu_available
def is_psutil_available():
return _psutil_available
def is_py3nvml_available():
return _py3nvml_available
def is_apex_available():
return _has_apex
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn
return docstring_decorator
def add_start_docstrings_to_callable(*docstr):
def docstring_decorator(fn):
class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
note = r"""
.. note::
Although the recipe for the forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this, since the former takes care of running the
pre- and post-processing steps while the latter silently ignores them.
"""
fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn
return docstring_decorator
def add_end_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr)
return fn
return docstring_decorator
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
>>> outputs = model(**inputs, labels=labels)
>>> loss, scores = outputs[:2]
"""
PT_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
>>> loss, start_scores, end_scores = outputs[:3]
"""
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
>>> outputs = model(**inputs, labels=labels)
>>> loss, logits = outputs[:2]
"""
PT_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
>>> outputs = model(input_ids, labels=input_ids)
>>> loss, prediction_scores = outputs[:2]
"""
PT_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
PT_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
>>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> loss, logits = outputs[:2]
"""
PT_CAUSAL_LM_SAMPLE = r"""
Example::
>>> import torch
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss, logits = outputs[:2]
"""
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> input_ids = inputs["input_ids"]
>>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
>>> outputs = model(inputs)
>>> loss, scores = outputs[:2]
"""
TF_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> input_dict = tokenizer(question, text, return_tensors='tf')
>>> start_scores, end_scores = model(input_dict)
>>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
>>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
>>> outputs = model(inputs)
>>> loss, logits = outputs[:2]
"""
TF_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
>>> outputs = model(input_ids)
>>> prediction_scores = outputs[0]
"""
TF_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
TF_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
>>> outputs = model(inputs) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> logits = outputs[0]
"""
TF_CAUSAL_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> logits = outputs[0]
"""
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
if "SequenceClassification" in model_class:
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
elif "QuestionAnswering" in model_class:
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
elif "TokenClassification" in model_class:
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
elif "MultipleChoice" in model_class:
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
elif "MaskedLM" in model_class:
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
elif "LMHead" in model_class:
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
elif "Model" in model_class:
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
return fn
return docstring_decorator
def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
"""
Resolve a model identifier, and a file name, to a HF-hosted url
on either S3 or Cloudfront (a Content Delivery Network, or CDN).
Cloudfront is replicated across the globe, so downloads are much faster
for the end user (and it also lowers our bandwidth costs). However, it
is more aggressively cached by default, so may not always reflect the
latest changes to the underlying file (default TTL is 24 hours).
In terms of client-side caching from this library, even though
Cloudfront relays the ETags from S3, using one or the other
(or switching from one to the other) will affect caching: cached files
are not shared between the two because the cached file's name contains
a hash of the url.
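The URL rule itself is two choices: the endpoint (CDN or S3) and the path format, where an identifier without a namespace uses the legacy ``<id>-<filename>`` form. A standalone sketch (the helper ``bucket_url`` is hypothetical; the two prefixes are copied from this module and are historical endpoints that may no longer resolve):

```python
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

def bucket_url(model_id, filename, use_cdn=True):
    # Pick the endpoint, then use the legacy "<id>-<file>" form for
    # un-namespaced ids and "<id>/<file>" for "user/model" ids.
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
    if "/" not in model_id:
        return f"{endpoint}/{model_id}-{filename}"
    return f"{endpoint}/{model_id}/{filename}"

print(bucket_url("bert-base-uncased", "config.json", use_cdn=False))
# https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json
print(bucket_url("dbmdz/bert-base-german-cased", "config.json"))
# https://cdn.huggingface.co/dbmdz/bert-base-german-cased/config.json
```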
"""
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
so that TF 2.0 can identify it as a HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
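The naming scheme can be reproduced with ``hashlib`` alone, which also shows why S3 and CDN copies of the same file are cached separately: the two URLs hash to different names. A standalone sketch (the helper ``cache_name`` and the example URLs are illustrative):

```python
from hashlib import sha256

def cache_name(url, etag=None):
    # sha256(url), optionally followed by "." + sha256(etag),
    # plus ".h5" when the url points to Keras HDF5 weights.
    name = sha256(url.encode("utf-8")).hexdigest()
    if etag:
        name += "." + sha256(etag.encode("utf-8")).hexdigest()
    if url.endswith(".h5"):
        name += ".h5"
    return name

s3 = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json"
cdn = "https://cdn.huggingface.co/bert-base-uncased-config.json"
assert cache_name(s3) != cache_name(cdn)  # S3 and CDN copies are cached separately
assert len(cache_name(s3, etag='"abc"')) == 64 + 1 + 64
assert cache_name("https://example.com/tf_model.h5").endswith(".h5")
```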
"""
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
if url.endswith(".h5"):
filename += ".h5"
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
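The metadata lives in a sidecar file named ``<cached file>.json`` holding ``url`` and ``etag`` keys, as read by this function (the sidecar is assumed to be written by the download code elsewhere in this module). A standalone sketch of the lookup; the helper ``read_cache_metadata`` and the file name are hypothetical:

```python
import json
import os
import tempfile

def read_cache_metadata(filename, cache_dir):
    # The cached file and its "<file>.json" sidecar must both exist.
    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + ".json"
    for path in (cache_path, meta_path):
        if not os.path.exists(path):
            raise EnvironmentError("file {} not found".format(path))
    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata["url"], metadata["etag"]

with tempfile.TemporaryDirectory() as cache_dir:
    name = "0123abcd"  # hypothetical hashed file name
    open(os.path.join(cache_dir, name), "w").close()
    with open(os.path.join(cache_dir, name + ".json"), "w", encoding="utf-8") as f:
        json.dump({"url": "https://example.com/config.json", "etag": None}, f)
    print(read_cache_metadata(name, cache_dir))  # ('https://example.com/config.json', None)
```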
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def cached_path(
url_or_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
extract_compressed_file=False,
force_extract=False,
local_files_only=False,
) -> Optional[str]:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Args:
cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
proxies: optional dict of proxies passed to `requests` on remote requests.
resume_download: if True, resume the download if an incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
file into a folder next to the archive.
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
re-extract the archive and overwrite the folder where it was extracted.
local_files_only: if True, make no network request and only look for the file in the local cache.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary)
output_path = get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
local_files_only=local_files_only,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
output_path = url_or_filename
elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
if extract_compressed_file:
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
return output_path
# Path where we extract compressed archives
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
output_dir, output_file = os.path.split(output_path)
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
return output_path_extracted
# Prevent parallel extractions
lock_path = output_path + ".lock"
with FileLock(lock_path):
shutil.rmtree(output_path_extracted, ignore_errors=True)
os.makedirs(output_path_extracted)
if is_zipfile(output_path):
with ZipFile(output_path, "r") as zip_file:
zip_file.extractall(output_path_extracted)
zip_file.close()
elif tarfile.is_tarfile(output_path):
tar_file = tarfile.open(output_path)
tar_file.extractall(output_path_extracted)
tar_file.close()
else:
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
return output_path_extracted
return output_path
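# Illustrative sketch, not part of the original module: the naming rule cached_path
# uses for the folder it extracts archives into ("./model.zip" => "./model-zip-extracted").
import os


def _sketch_extract_dir(output_path):
    output_dir, output_file = os.path.split(output_path)
    # Every "." in the archive name becomes "-", then "-extracted" is appended,
    # and the folder is placed next to the archive.
    return os.path.join(output_dir, output_file.replace(".", "-") + "-extracted")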
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
ua += "; torch/{}".format(torch.__version__)
if is_tf_available():
ua += "; tensorflow/{}".format(tf.__version__)
if isinstance(user_agent, dict):
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
headers = {"user-agent": ua}
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
content_length = response.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc="Downloading",
disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
)
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
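# Illustrative sketch, not part of the original module: how http_get resumes a
# partial download. With resume_size bytes already on disk, it requests only the
# remainder through an HTTP Range header. The progress-bar total is the bytes
# already held plus the Content-Length of the remaining part. If the file is
# already complete, the server answers 416 (Range Not Satisfiable) and http_get
# returns without writing anything. The helper name is hypothetical.
def _sketch_resume_request(resume_size, content_length):
    headers = {}
    if resume_size > 0:
        # "bytes=N-" asks the server for everything from byte offset N to the end.
        headers["Range"] = "bytes=%d-" % (resume_size,)
    # content_length arrives as a header string (or None when the server omits it).
    total = resume_size + int(content_length) if content_length is not None else None
    return headers, total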
def get_from_cache(
url,
cache_dir=None,
force_download=False,
proxies=None,
etag_timeout=10,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
Given a URL, look for the corresponding file in the local cache.
If it's not there, download it. Then return the path to the cached file.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
etag = None
if not local_files_only:
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
# etag is already None
pass
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# etag is None means we don't have a connection, or the url doesn't exist or is otherwise inaccessible.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
return cache_path
else:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1])
else:
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise ValueError(
"Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
return None
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
return cache_path
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
with FileLock(lock_path):
# If the download just completed while we were waiting for the lock.
if os.path.exists(cache_path) and not force_download:
# Even if returning early like here, the lock will be released.
return cache_path
if resume_download:
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager():
with open(incomplete_path, "a+b") as f:
yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
return cache_path
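# Illustrative sketch, not part of the original module: the offline fallback in
# get_from_cache. Without an ETag, the expected name has no ETag part, so the
# function looks for any "<url hash>.<etag hash>" entry left by an earlier online
# run, skipping the ".json" metadata and ".lock" files. The helper name and the
# sample listing in the test are hypothetical.
import fnmatch


def _sketch_offline_match(listing, url_hash):
    matches = [
        name
        for name in fnmatch.filter(listing, url_hash + ".*")
        if not name.endswith(".json") and not name.endswith(".lock")
    ]
    # get_from_cache takes matching_files[-1] straight from os.listdir, whose order
    # is arbitrary. This sketch sorts first so the choice is deterministic (a
    # deviation from the original, and not necessarily the newest file either).
    return sorted(matches)[-1] if matches else None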
class cached_property(property):
"""
Descriptor that mimics @property but caches output in a member variable.
From tensorflow_datasets
Available in the standard library as `functools.cached_property` from Python 3.8.
"""
def __get__(self, obj, objtype=None):
# See docs.python.org/3/howto/descriptor.html#properties
if obj is None:
return self
if self.fget is None:
raise AttributeError("unreadable attribute")
attr = "__cached_" + self.fget.__name__
cached = getattr(obj, attr, None)
if cached is None:
cached = self.fget(obj)
setattr(obj, attr, cached)
return cached
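# Illustrative sketch, not part of the original module: the class above stores
# its result under "__cached_<name>" and treats None as "not cached yet". As a
# result, a property that returns None is recomputed on every access. The stdlib
# functools.cached_property (Python 3.8+) caches None as well. The descriptor is
# copied here as _SketchCachedProperty so the demo runs standalone.
import functools


class _SketchCachedProperty(property):
    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, None)
        if cached is None:  # None doubles as the "missing" sentinel
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached


class _SketchDemo:
    def __init__(self, value):
        self.value = value
        self.calls = 0  # counts how often a getter body actually runs

    @_SketchCachedProperty
    def expensive(self):
        self.calls += 1
        return self.value

    @functools.cached_property
    def expensive_stdlib(self):
        self.calls += 1
        return self.value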
def torch_required(func):
# We chose a different decorator name than the one used in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
return wrapper
def tf_required(func):
# We chose a different decorator name than the one used in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
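# Illustrative sketch, not part of the original module: torch_required and
# tf_required share one pattern. Each is a decorator that checks a backend
# predicate at call time and raises ImportError otherwise. `_sketch_requires` is
# a hypothetical factory that builds such decorators from any predicate.
from functools import wraps


def _sketch_requires(is_available, backend_name):
    def decorator(func):
        @wraps(func)  # keeps func.__name__ and the docstring on the wrapper
        def wrapper(*args, **kwargs):
            if is_available():
                return func(*args, **kwargs)
            raise ImportError(f"Method `{func.__name__}` requires {backend_name}.")
        return wrapper
    return decorator


# Usage: one backend that is "available" and one that is not.
@_sketch_requires(lambda: True, "PyTorch")
def _sketch_always_ok(x):
    return x + 1


@_sketch_requires(lambda: False, "TF")
def _sketch_never_ok(x):
    return x + 1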
================================================
FILE: bert/generation_utils.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import functional as F
logger = logging.getLogger(__name__)
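# Illustrative sketch, not part of the original module: the CTRL-style repetition
# penalty that GenerationMixin.enforce_repetition_penalty_ below applies in place
# to a tensor of scores. Plain Python lists stand in for tensors here. Negative
# scores are multiplied and non-negative scores divided by the penalty, so with
# penalty > 1 every previously generated token becomes less likely to be picked.
def _sketch_repetition_penalty(scores, prev_tokens, penalty):
    out = list(scores)  # work on a copy; the original mutates its tensor in place
    for tok in set(prev_tokens):  # each previous token is penalised once
        if out[tok] < 0:
            out[tok] *= penalty
        else:
            out[tok] /= penalty
    return out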
class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
"""
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def adjust_logits_during_generation(self, logits, **kwargs):
return logits
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
if len(outputs) <= 1 or use_cache is False:
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
def postprocess_next_token_scores(
self,
scores,
input_ids,
no_repeat_ngram_size,
bad_words_ids,
cur_len,
min_length,
max_length,
eos_token_id,
repetition_penalty,
batch_size,
num_beams,
):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(
scores, batch_size, num_beams, input_ids, repetition_penalty,
)
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_batch_tokens = calc_banned_ngram_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
for i, banned_tokens in enumerate(banned_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
) -> torch.LongTensor:
r""" Generates sequences for models with an LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, and sampling with top-k or nucleus (top-p) filtering.
Adapted in part from `Facebook's XLM beam search code`_.
.. _`Facebook's XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Between `min_length` and infinity. Defaults to 20.
min_length: (`optional`) int
The min length of the sequence to be generated. Between 0 and infinity. Defaults to 0.
do_sample: (`optional`) bool
If set to `False`, greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
early_stopping: (`optional`) bool
If set to `True`, beam search is stopped when at least `num_beams` sentences are finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.
temperature: (`optional`) float
The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
top_k: (`optional`) int
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.
top_p: (`optional`) float
The cumulative probability threshold for nucleus sampling: only the most probable vocabulary tokens whose probabilities add up to `top_p` or higher are kept. Must be between 0 and 1. Defaults to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.
pad_token_id: (`optional`) int
Padding token. Defaults to the model-specific `pad_token_id`, or None if it does not exist.
bos_token_id: (`optional`) int
BOS token. Defaults to `bos_token_id` as defined in the model's config.
eos_token_id: (`optional`) int
EOS token. Defaults to `eos_token_id` as defined in the model's config.
length_penalty: (`optional`) float
Exponential penalty to the length, used by beam search. Defaults to 1.
no_repeat_ngram_size: (`optional`) int
If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
bad_words_ids: (`optional`) list of lists of int
`bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Defaults to 1.
attention_mask: (`optional`) `torch.LongTensor` of the same shape as `input_ids`
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
Defaults to `None`.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_start_token_id: (`optional`) int
If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
model_specific_kwargs: (`optional`) dict
Additional model specific kwargs will be forwarded to the `forward` function of the model.
Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
Examples::
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40) # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # beam search decoding (5 beams) from the context 'The dog', returning the 3 best beams; temperature only has an effect when do_sample=True
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 sequences by sampling
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
input_context = 'My cute dog'
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
"""
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have an LM Head. "
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
)
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
if input_ids is not None:
batch_size = input_ids.shape[0] # overridden by the input batch_size
else:
batch_size = 1
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a non-negative integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a non-negative integer."
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a non-negative integer."
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a non-negative integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
), "`no_repeat_ngram_size` should be a non-negative integer."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = torch.full(
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# do not allow duplicate outputs when decoding greedily
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should be handled by the forward fn() in each model in the future, see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created
if pad_token_id is None and eos_token_id is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
)
pad_token_id = eos_token_id
# current position and vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
self.config.is_encoder_decoder
and hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else:
effective_batch_size = batch_size
effective_batch_mult = 1
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
# create empty decoder_input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
cur_len = 1
assert (
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
torch.arange(batch_size)
.view(-1, 1)
.repeat(1, num_beams * effective_batch_mult)
.view(-1)
.to(input_ids.device)
)
# expand encoder_outputs
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
assert (
cur_len < max_length
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
return output
def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = (encoder_outputs, None) if encoder_outputs is not None else None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
scores = self.postprocess_next_token_scores(
scores=next_token_logits,
input_ids=input_ids,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
cur_len=cur_len,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=1,
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
# Sample
probs = F.softmax(next_token_logscores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
# Greedy decoding
next_token = torch.argmax(scores, dim=-1)
# update generations and finished sentences
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
# add token and increase length by one
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return input_ids
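# Illustrative sketch, not part of the library. It is a dependency-free approximation of the per-step
# logic in _generate_no_beam_search above, using plain Python lists instead of torch tensors.
# It covers greedy argmax versus temperature sampling, plus the tokens_to_add / unfinished_sents
# bookkeeping that pads a row once it has emitted eos_token_id. Top-k/top-p filtering is omitted.
# The helper names (softmax, pick_next_token, add_tokens) are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability; -inf entries get probability 0.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_next_token(
    scores: Sequence[float], do_sample: bool, temperature: float = 1.0, rng: Optional[random.Random] = None
) -> int:
    """Greedy argmax, or temperature sampling, over one row of post-processed scores."""
    if not do_sample:
        return max(range(len(scores)), key=lambda i: scores[i])
    if temperature != 1.0:
        # Higher temperature flattens the distribution; lower temperature sharpens it.
        scores = [s / temperature for s in scores]
    probs = softmax(scores)
    rng = rng or random.Random()
    return rng.choices(range(len(scores)), weights=probs, k=1)[0]


def add_tokens(
    next_tokens: Sequence[int], unfinished: Sequence[int], eos_id: int, pad_id: int
) -> Tuple[List[int], List[int]]:
    """Mirror of the tokens_to_add / unfinished_sents update (1 = still generating)."""
    # Finished rows receive pad_id instead of the newly chosen token.
    to_add = [t * u + pad_id * (1 - u) for t, u in zip(next_tokens, unfinished)]
    # A row becomes finished as soon as it emits eos_id.
    still_open = [u * (0 if t == eos_id else 1) for t, u in zip(to_add, unfinished)]
    return to_add, still_open
```

For example, `add_tokens([7, 2, 9], [1, 1, 0], eos_id=2, pad_id=0)` keeps the first row open, closes the second because it emitted EOS, and pads the third, which had already finished.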
def _generate_beam_search(
self,
input_ids,
cur_len,
max_length,
min_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example with beam search.
"""
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# for greedy decoding, only tokens of the first beam are considered at the first step, to avoid selecting the exact same tokens num_beams times
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = (encoder_outputs, None) if encoder_outputs is not None else None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
scores = self.postprocess_next_token_scores(
scores=scores,
input_ids=input_ids,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
cur_len=cur_len,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=num_beams,
)
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
if do_sample:
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Temperature
if temperature != 1.0:
_scores = _scores / temperature
# Top-p/top-k filtering
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
probs = F.softmax(_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else:
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence, add a pad token
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token_id have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content, this will get added to next_batch_beam
next_sent_beam = []
# next tokens for this sentence
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and token IDs
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(), beam_token_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams:
break
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
)
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
# stop when we are done with each sentence
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and update current length
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
cur_len = cur_len + 1
# re-order internal states
if past is not None:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx in range(batch_size):
if done[batch_idx]:
continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_id is not None and all(
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to be equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
)
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(num_beams):
effective_beam_id = batch_idx * num_beams + beam_id
final_score = beam_scores[effective_beam_id].item()
final_tokens = input_ids[effective_beam_id]
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths = input_ids.new(output_batch_size)
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
effective_batch_idx = output_num_return_sequences_per_batch * i + j
best_hyp = sorted_hyps.pop()[1]
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# shorter batches are padded
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
# all hypotheses have the same length, so they can be stacked without padding
assert all(len(hypo) == len(best[0]) for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded
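# Illustrative sketch, not part of the library. It shows the candidate selection in
# _generate_beam_search for one sentence in plain Python. Each beam's running score is added to
# its next-token log-probs, the (num_beams, vocab_size) grid is flattened, and the best k entries
# are taken (k is 2 * num_beams in the loop above). beam_id and token_id are then recovered with
# // and %. The function name top_beam_candidates is hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple


def top_beam_candidates(
    beam_scores: Sequence[float], log_probs: Sequence[Sequence[float]], k: int
) -> List[Tuple[float, int, int]]:
    """Best k (score, beam_id, token_id) triples across all beams of one sentence."""
    vocab_size = len(log_probs[0])
    # Flatten (num_beams, vocab_size) into one row, like the view() call above.
    flat = [beam_scores[b] + lp for b, row in enumerate(log_probs) for lp in row]
    best = heapq.nlargest(k, range(len(flat)), key=lambda i: flat[i])
    # Recover beam and token ids from the flat index.
    return [(flat[i], i // vocab_size, i % vocab_size) for i in best]
```

The spare k - num_beams candidates allow the loop to skip candidates that end in EOS and still fill num_beams open beams.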
@staticmethod
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
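# Illustrative sketch, not part of the library. It shows what _reorder_cache does per layer,
# assuming a GPT-2 style cache in which each layer's past stacks keys and values on dim 0, so that
# dim 1 indexes the batch * num_beams rows. Models with a different cache layout may override
# _reorder_cache. The name reorder_layer_past is hypothetical, and nested lists stand in for tensors.

```python
from typing import List, Sequence


def reorder_layer_past(layer_past: Sequence[Sequence[str]], beam_idx: Sequence[int]) -> List[List[str]]:
    """Select entries along dim 1 (the beam axis), like index_select(1, beam_idx)."""
    return [[kv[i] for i in beam_idx] for kv in layer_past]
```

With `beam_idx = [2, 2, 0]`, two surviving beams both continue from old beam 2 and one from old beam 0, so their cached keys and values are copied accordingly.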
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> Iterable[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
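# Illustrative sketch, not part of the library. It is a single-hypothesis, plain-Python version of
# calc_banned_ngram_tokens above. Every n-gram seen so far is recorded by its (n - 1)-token prefix,
# and the tokens that would complete an n-gram starting with the current last n - 1 tokens are
# banned. It assumes cur_len == len(tokens). The name banned_ngram_tokens is hypothetical.

```python
from typing import Dict, List, Tuple


def banned_ngram_tokens(tokens: List[int], n: int) -> List[int]:
    """Tokens that would repeat an n-gram already present in `tokens`."""
    if len(tokens) + 1 < n:
        # Not enough tokens generated yet to repeat an n-gram.
        return []
    seen: Dict[Tuple[int, ...], List[int]] = {}
    for ngram in zip(*[tokens[i:] for i in range(n)]):
        seen.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    # The last n - 1 tokens are the prefix of the n-gram about to be completed.
    prefix = tuple(tokens[len(tokens) + 1 - n:])
    return seen.get(prefix, [])
```

For example, with `n=3` and history `[1, 2, 3, 1, 2]`, token 3 is banned because `1 2 3` already occurred.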
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if the bad word is a single token, always ban it
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev input_ids they can't be equal
return False
if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
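# Illustrative sketch, not part of the library. It is a single-hypothesis version of
# calc_banned_bad_words_ids above. The last token of a bad-word sequence is banned when the rest
# of the sequence matches the end of the tokens generated so far. Single-token bad words are
# always banned. The name banned_bad_word_tokens is hypothetical.

```python
from typing import List, Sequence


def banned_bad_word_tokens(prev_tokens: Sequence[int], bad_words_ids: Sequence[Sequence[int]]) -> List[int]:
    """Last tokens of bad-word sequences whose prefix ends `prev_tokens`."""
    prev = list(prev_tokens)
    banned = []
    for seq in bad_words_ids:
        if len(seq) == 0:
            raise ValueError("bad word token sequences cannot be empty")
        prefix = list(seq[:-1])
        # An empty prefix (single-token bad word) always matches.
        if not prefix or (len(prefix) <= len(prev) and prev[-len(prefix):] == prefix):
            banned.append(seq[-1])
    return banned
```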
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (tokens marked 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep tokens (the right shift below then keeps one more)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
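# Illustrative sketch, not part of the library. It is a plain-Python version of
# top_k_top_p_filtering for a single row of logits. Top-k drops everything below the k-th largest
# logit. Top-p drops a token once the tokens ranked above it already exceed top_p, which is what
# the right shift above achieves. This keeps the token that crosses the threshold. One
# simplification: when min_tokens_to_keep > 1, the tensor version keeps one extra token because of
# the shift, while this sketch keeps exactly min_tokens_to_keep as a floor. The name
# filter_logits is hypothetical.

```python
import math
from typing import List, Sequence


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def filter_logits(
    logits: Sequence[float], top_k: int = 0, top_p: float = 1.0, min_tokens_to_keep: int = 1
) -> List[float]:
    neg_inf = float("-inf")
    out = list(logits)
    if top_k > 0:
        k = min(max(top_k, min_tokens_to_keep), len(out))
        kth = sorted(out, reverse=True)[k - 1]
        # Ties with the k-th largest value survive, matching the `<` comparison above.
        out = [x if x >= kth else neg_inf for x in out]
    if top_p < 1.0:
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        probs = _softmax([out[i] for i in order])
        mass_before = 0.0
        for rank, i in enumerate(order):
            if rank >= min_tokens_to_keep and mass_before > top_p:
                out[i] = neg_inf
            mass_before += probs[rank]
    return out
```

With probabilities 0.5, 0.3, 0.15 and 0.05 and `top_p=0.7`, the first two tokens survive. The second is kept because only 0.5 of the mass lies above it.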
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
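# Illustrative sketch, not part of the library. It restates the scoring used by BeamHypotheses.add
# and the stopping test in BeamHypotheses.is_done as standalone functions (hypothetical names).
# A hypothesis is scored as sum_logprobs / len(hyp) ** length_penalty. A penalty above 0 therefore
# favours longer outputs relative to the raw log-probability sum.

```python
def length_normalized_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Same formula as BeamHypotheses.add.
    return sum_logprobs / length ** length_penalty


def beam_search_done(
    num_finished: int,
    num_beams: int,
    worst_score: float,
    best_sum_logprobs: float,
    cur_len: int,
    length_penalty: float,
    early_stopping: bool,
) -> bool:
    # Same decision as BeamHypotheses.is_done.
    if num_finished < num_beams:
        return False
    if early_stopping:
        return True
    # Done when even the best open beam cannot beat the worst kept hypothesis.
    return worst_score >= length_normalized_score(best_sum_logprobs, cur_len, length_penalty)
```

With `length_penalty=1.0`, a 10-token hypothesis with log-prob sum -5.0 (score -0.5) beats a 4-token one with -3.0 (score -0.75). With `length_penalty=0.0` the shorter one wins.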
================================================
FILE: bert/modeling_bert.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """
import logging
import math
import os
import warnings
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_TOKENIZER_FOR_DOC = "BertTokenizer"
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bert-base-uncased",
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
"bert-base-multilingual-uncased",
"bert-base-multilingual-cased",
"bert-base-chinese",
"bert-base-german-cased",
"bert-large-uncased-whole-word-masking",
"bert-large-cased-whole-word-masking",
"bert-large-uncased-whole-word-masking-finetuned-squad",
"bert-large-cased-whole-word-masking-finetuned-squad",
"bert-base-cased-finetuned-mrpc",
"bert-base-german-dbmdz-cased",
"bert-base-german-dbmdz-uncased",
"cl-tohoku/bert-base-japanese",
"cl-tohoku/bert-base-japanese-whole-word-masking",
"cl-tohoku/bert-base-japanese-char",
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
"TurkuNLP/bert-base-finnish-cased-v1",
"TurkuNLP/bert-base-finnish-uncased-v1",
"wietsedv/bert-base-dutch-cased",
# See all BERT models at https://huggingface.co/models?filter=bert
]
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info("Skipping {}".format("/".join(name)))
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
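# Illustrative sketch, not part of the library. It shows how load_tf_weights_in_bert maps one
# TensorFlow variable name onto a PyTorch attribute path. Scopes like "layer_11" are split into a
# name and a ModuleList index. kernel/gamma/output_weights map to "weight", beta/output_bias map to
# "bias", and *_embeddings scopes get a trailing "weight". The real loader also skips optimizer
# slots, transposes kernel arrays, and assigns pointer.data. This sketch only builds the dotted
# path. The helper names split_scope and tf_name_to_torch_path are hypothetical.

```python
import re
from typing import List


def split_scope(m_name: str) -> List[str]:
    # "layer_11" -> ["layer", "11", ""]; names without a numeric suffix stay whole.
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
        return re.split(r"_(\d+)", m_name)
    return [m_name]


def tf_name_to_torch_path(tf_name: str) -> str:
    """Dotted attribute path the loop above would walk for one TF variable name."""
    path = []
    for m_name in tf_name.split("/"):
        scope = split_scope(m_name)
        if scope[0] in ("kernel", "gamma", "output_weights"):
            path.append("weight")
        elif scope[0] in ("beta", "output_bias"):
            path.append("bias")
        elif scope[0] == "squad":
            path.append("classifier")
        else:
            path.append(scope[0])
        if len(scope) >= 2:
            path.append(scope[1])  # index into a ModuleList, e.g. layer[11]
        if m_name.endswith("_embeddings"):
            path.append("weight")
    return ".".join(path)
```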
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
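# Illustrative sketch, not part of the library. It is a scalar version of mish(x) = x * tanh(softplus(x)),
# where softplus(x) = log(1 + e^x). softplus is written in a numerically stable form here.
# The names softplus and mish_scalar are hypothetical.

```python
import math


def softplus(x: float) -> float:
    # log(1 + e^x) == max(x, 0) + log1p(e^-|x|), which avoids overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish_scalar(x: float) -> float:
    return x * math.tanh(softplus(x))
```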
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
BertLayerNorm = torch.nn.LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
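# Illustrative sketch, not part of the library. It shows one head of BertSelfAttention in plain
# Python: softmax(q k^T / sqrt(head_size) + mask) v. It also shows how transpose_for_scores splits
# a token's all_head_size vector into per-head chunks. The mask is assumed to follow the additive
# convention (1 - mask) * -10000.0, so padded positions get a large negative bias before the
# softmax. Dropout and head_mask are omitted. All helper names are hypothetical.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def split_heads(vec: Sequence[float], num_heads: int) -> Matrix:
    """One token's all_head_size vector -> num_heads chunks of attention_head_size."""
    head_size = len(vec) // num_heads
    return [list(vec[h * head_size:(h + 1) * head_size]) for h in range(num_heads)]


def additive_mask(mask01: Sequence[int]) -> List[float]:
    # Assumed convention: 1 = attend, 0 = padding -> large negative bias.
    return [(1.0 - m) * -10000.0 for m in mask01]


def single_head_attention(q: Matrix, k: Matrix, v: Matrix, mask: Optional[Sequence[float]] = None) -> Matrix:
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask)]
        mx = max(scores)
        weights = [math.exp(s - mx) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of value vectors, column by column.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out
```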
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
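# Illustrative sketch, not part of the library. prune_heads relies on find_pruneable_heads_and_indices
# from modeling_utils, which is not shown here. This sketch assumes it returns the newly pruned heads
# and the flat row indices of the query/key/value outputs that survive. Head ids use the layer's
# original numbering, so each id is shifted down by the heads already removed. The name
# prune_head_indices is hypothetical.

```python
from typing import Iterable, List, Set, Tuple


def prune_head_indices(
    heads: Iterable[int], n_heads: int, head_size: int, already_pruned: Set[int]
) -> Tuple[Set[int], List[int]]:
    """Return (heads pruned now, flat indices to keep) for one attention layer."""
    new_heads = set(heads) - already_pruned
    dropped_rows = set()
    for head in new_heads:
        # Current position of this head after earlier pruning.
        pos = head - sum(1 for h in already_pruned if h < head)
        dropped_rows.update(range(pos * head_size, (pos + 1) * head_size))
    keep = [i for i in range(n_heads * head_size) if i not in dropped_rows]
    return new_heads, keep
```

The same index list would select output rows of query/key/value (dim 0) and input columns of output.dense (dim 1), matching the prune_linear_layer calls above.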
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
if self.is_decoder:
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + outputs
return outputs
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
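# Illustrative sketch, not part of the library. It shows BertPooler for one sequence in plain Python:
# tanh(W h_0 + b), where h_0 is the hidden state of the first ([CLS]) token. W is laid out as
# nn.Linear stores it, with shape (out_features, in_features). The name pool_first_token is
# hypothetical.

```python
import math
from typing import List, Sequence


def pool_first_token(
    hidden_states: Sequence[Sequence[float]], weight: Sequence[Sequence[float]], bias: Sequence[float]
) -> List[float]:
    """tanh(W @ h_0 + b); the remaining tokens are ignored."""
    first = hidden_states[0]
    return [math.tanh(sum(w * x for w, x in zip(row, first)) + b) for row, b in zip(weight, bias)]
```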
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
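

# Illustrative sketch (hypothetical helpers, not part of the original file):
# how the two outputs of BertPreTrainingHeads can be combined into the total
# pretraining loss that BertForPreTraining's docstring describes (masked LM
# loss + next sentence prediction loss, with MLM labels of -100 ignored).
# Assumption: mean-reduced cross-entropy, as torch.nn.CrossEntropyLoss() does
# by default; the loss code of BertForPreTraining.forward is not part of this
# excerpt. At least one MLM label must differ from -100.
import math


def _sketch_cross_entropy(logits, target):
    # -log softmax(logits)[target], using the log-sum-exp trick for stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_norm - logits[target]


def _sketch_pretraining_loss(prediction_scores, labels, seq_relationship_scores, next_sentence_labels):
    # prediction_scores: [batch][seq_len][vocab]; labels: [batch][seq_len].
    mlm_terms = [
        _sketch_cross_entropy(scores, label)
        for seq_scores, seq_labels in zip(prediction_scores, labels)
        for scores, label in zip(seq_scores, seq_labels)
        if label != -100  # -100 marks positions excluded from the MLM loss
    ]
    mlm_loss = sum(mlm_terms) / len(mlm_terms)
    # NSP labels: 0 = sentence B continues sentence A, 1 = sentence B is random.
    nsp_terms = [_sketch_cross_entropy(s, l) for s, l in zip(seq_relationship_scores, next_sentence_labels)]
    nsp_loss = sum(nsp_terms) / len(nsp_terms)
    return mlm_loss + nsp_loss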


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weight initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
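

# Illustrative sketch (hypothetical helpers, not part of the original file)
# mirroring the rules in _init_weights above: Linear/Embedding weights are drawn
# from N(0, initializer_range) (a plain normal, not TF's truncated normal),
# Linear biases are zeroed, and LayerNorm starts as the identity affine map.
# Assumption: initializer_range=0.02, BertConfig's default.
import random


def _sketch_init_linear(out_features, in_features, initializer_range=0.02, seed=0):
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    weight = [[rng.gauss(0.0, initializer_range) for _ in range(in_features)] for _ in range(out_features)]
    bias = [0.0] * out_features
    return weight, bias


def _sketch_init_layer_norm(size):
    # weight (gamma) = 1, bias (beta) = 0
    return [1.0] * size, [0.0] * size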


BERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder, the model needs to be initialized with the
    :obj:`is_decoder` argument of the configuration set to :obj:`True`, and
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during pre-training.

            This output is usually *not* a good summary
            of the semantic content of the input; you are often better off averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates that we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
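

# Illustrative sketch (hypothetical helpers, not part of the original file) of
# the two masks BertModel.forward prepares above, written with plain lists.
# Assumptions: a 2D attention mask becomes an additive mask of shape
# [batch][1][1][seq_len] holding 0.0 (attend) or -10000.0 (padding), which is
# our reading of get_extended_attention_mask in transformers 3.x; a head mask
# given as [num_heads] is shared by every layer, while one given as
# [num_hidden_layers][num_heads] is used per layer. The real get_head_mask
# expands these into broadcastable 5D tensors instead of nested lists.
import math


def _sketch_extended_attention_mask(attention_mask):
    # 1 = real token, 0 = padding  ->  0.0 / -10000.0, broadcastable over heads and queries.
    return [[[[(1.0 - m) * -10000.0 for m in row]]] for row in attention_mask]


def _sketch_masked_softmax(scores, additive_mask):
    # Adding the mask before the softmax drives padded positions to ~0 probability.
    shifted = [s + m for s, m in zip(scores, additive_mask)]
    top = max(shifted)
    exps = [math.exp(v - top) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def _sketch_per_layer_head_mask(head_mask, num_hidden_layers):
    if head_mask is None:
        return [None] * num_hidden_layers  # no head is masked in any layer
    if not isinstance(head_mask[0], (list, tuple)):
        return [list(head_mask) for _ in range(num_hidden_layers)]  # [num_heads], shared by all layers
    if len(head_mask) != num_hidden_layers:
        raise ValueError("a 2D head_mask needs one row per hidden layer")
    return [list(row) for row in head_mask]  # already [num_hidden_layers][num_heads]


def _sketch_apply_head_mask(attention_probs, layer_head_mask):
    # attention_probs: [num_heads][query_len][key_len]; 1.0 keeps a head, 0.0 silences it.
    if layer_head_mask is None:
        return attention_probs
    return [
        [[p * keep for p in row] for row in head]
        for head, keep in zip(attention_probs, layer_head_mask)
    ]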


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size - 1]``.
        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring).
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        >>> from transformers import BertTokenizer, BertForPreTraining
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_scores, seq_relationship_scores = outputs[:2]
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
SYMBOL INDEX (583 symbols across 33 files)
FILE: args.py
function get_parser (line 4) | def get_parser():
FILE: bert/activations.py
function swish (line 11) | def swish(x):
function _gelu_python (line 15) | def _gelu_python(x):
function gelu_new (line 25) | def gelu_new(x):
function gelu_fast (line 38) | def gelu_fast(x):
function get_activation (line 52) | def get_activation(activation_string):
FILE: bert/configuration_bert.py
class BertConfig (line 53) | class BertConfig(PretrainedConfig):
method __init__ (line 111) | def __init__(
FILE: bert/configuration_utils.py
class PretrainedConfig (line 31) | class PretrainedConfig(object):
method __init__ (line 56) | def __init__(self, **kwargs):
method num_labels (line 118) | def num_labels(self):
method num_labels (line 122) | def num_labels(self, num_labels):
method save_pretrained (line 126) | def save_pretrained(self, save_directory):
method from_pretrained (line 145) | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "...
method get_config_dict (line 204) | def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs)...
method from_dict (line 269) | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
method from_json_file (line 307) | def from_json_file(cls, json_file: str) -> "PretrainedConfig":
method _dict_from_json_file (line 323) | def _dict_from_json_file(cls, json_file: str):
method __eq__ (line 328) | def __eq__(self, other):
method __repr__ (line 331) | def __repr__(self):
method to_diff_dict (line 334) | def to_diff_dict(self):
method to_dict (line 357) | def to_dict(self):
method to_json_string (line 369) | def to_json_string(self, use_diff=True):
method to_json_file (line 386) | def to_json_file(self, json_file_path, use_diff=True):
method update (line 399) | def update(self, config_dict: Dict):
FILE: bert/file_utils.py
function is_torch_available (line 131) | def is_torch_available():
function is_tf_available (line 135) | def is_tf_available():
function is_torch_tpu_available (line 139) | def is_torch_tpu_available():
function is_psutil_available (line 143) | def is_psutil_available():
function is_py3nvml_available (line 147) | def is_py3nvml_available():
function is_apex_available (line 151) | def is_apex_available():
function add_start_docstrings (line 155) | def add_start_docstrings(*docstr):
function add_start_docstrings_to_callable (line 163) | def add_start_docstrings_to_callable(*docstr):
function add_end_docstrings (line 181) | def add_end_docstrings(*docstr):
function add_code_sample_docstrings (line 417) | def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint...
function is_remote_url (line 446) | def is_remote_url(url_or_filename):
function hf_bucket_url (line 451) | def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
function url_to_filename (line 475) | def url_to_filename(url, etag=None):
function filename_to_url (line 499) | def filename_to_url(filename, cache_dir=None):
function cached_path (line 525) | def cached_path(
function http_get (line 617) | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Un...
function get_from_cache (line 650) | def get_from_cache(
class cached_property (line 764) | class cached_property(property):
method __get__ (line 773) | def __get__(self, obj, objtype=None):
function torch_required (line 787) | def torch_required(func):
function tf_required (line 799) | def tf_required(func):
FILE: bert/generation_utils.py
class GenerationMixin (line 28) | class GenerationMixin:
method prepare_inputs_for_generation (line 33) | def prepare_inputs_for_generation(self, input_ids, **kwargs):
method adjust_logits_during_generation (line 36) | def adjust_logits_during_generation(self, logits, **kwargs):
method _use_cache (line 39) | def _use_cache(self, outputs, use_cache):
method enforce_repetition_penalty_ (line 47) | def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, p...
method postprocess_next_token_scores (line 57) | def postprocess_next_token_scores(
method generate (line 101) | def generate(
method _generate_no_beam_search (line 485) | def _generate_no_beam_search(
method _generate_beam_search (line 585) | def _generate_beam_search(
method _reorder_cache (line 844) | def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
function calc_banned_ngram_tokens (line 848) | def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_...
function calc_banned_bad_words_ids (line 871) | def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_i...
function top_k_top_p_filtering (line 907) | def top_k_top_p_filtering(
class BeamHypotheses (line 948) | class BeamHypotheses(object):
method __init__ (line 949) | def __init__(self, num_beams, max_length, length_penalty, early_stoppi...
method __len__ (line 960) | def __len__(self):
method add (line 966) | def add(self, hyp, sum_logprobs):
method is_done (line 980) | def is_done(self, best_sum_logprobs, cur_len):
FILE: bert/modeling_bert.py
function load_tf_weights_in_bert (line 66) | def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
function mish (line 138) | def mish(x):
class BertEmbeddings (line 148) | class BertEmbeddings(nn.Module):
method __init__ (line 152) | def __init__(self, config):
method forward (line 163) | def forward(self, input_ids=None, token_type_ids=None, position_ids=No...
class BertSelfAttention (line 188) | class BertSelfAttention(nn.Module):
method __init__ (line 189) | def __init__(self, config):
method transpose_for_scores (line 207) | def transpose_for_scores(self, x):
method forward (line 212) | def forward(
class BertSelfOutput (line 266) | class BertSelfOutput(nn.Module):
method __init__ (line 267) | def __init__(self, config):
method forward (line 273) | def forward(self, hidden_states, input_tensor):
class BertAttention (line 280) | class BertAttention(nn.Module):
method __init__ (line 281) | def __init__(self, config):
method prune_heads (line 287) | def prune_heads(self, heads):
method forward (line 305) | def forward(
class BertIntermediate (line 322) | class BertIntermediate(nn.Module):
method __init__ (line 323) | def __init__(self, config):
method forward (line 331) | def forward(self, hidden_states):
class BertOutput (line 337) | class BertOutput(nn.Module):
method __init__ (line 338) | def __init__(self, config):
method forward (line 344) | def forward(self, hidden_states, input_tensor):
class BertLayer (line 351) | class BertLayer(nn.Module):
method __init__ (line 352) | def __init__(self, config):
method forward (line 361) | def forward(
class BertEncoder (line 394) | class BertEncoder(nn.Module):
method __init__ (line 395) | def __init__(self, config):
method forward (line 400) | def forward(
class BertPooler (line 458) | class BertPooler(nn.Module):
method __init__ (line 459) | def __init__(self, config):
method forward (line 464) | def forward(self, hidden_states):
class BertPredictionHeadTransform (line 473) | class BertPredictionHeadTransform(nn.Module):
method __init__ (line 474) | def __init__(self, config):
method forward (line 483) | def forward(self, hidden_states):
class BertLMPredictionHead (line 490) | class BertLMPredictionHead(nn.Module):
method __init__ (line 491) | def __init__(self, config):
method forward (line 504) | def forward(self, hidden_states):
class BertOnlyMLMHead (line 510) | class BertOnlyMLMHead(nn.Module):
method __init__ (line 511) | def __init__(self, config):
method forward (line 515) | def forward(self, sequence_output):
class BertOnlyNSPHead (line 520) | class BertOnlyNSPHead(nn.Module):
method __init__ (line 521) | def __init__(self, config):
method forward (line 525) | def forward(self, pooled_output):
class BertPreTrainingHeads (line 530) | class BertPreTrainingHeads(nn.Module):
method __init__ (line 531) | def __init__(self, config):
method forward (line 536) | def forward(self, sequence_output, pooled_output):
class BertPreTrainedModel (line 542) | class BertPreTrainedModel(PreTrainedModel):
method _init_weights (line 551) | def _init_weights(self, module):
class BertModel (line 627) | class BertModel(BertPreTrainedModel):
method __init__ (line 644) | def __init__(self, config):
method get_input_embeddings (line 654) | def get_input_embeddings(self):
method set_input_embeddings (line 657) | def set_input_embeddings(self, value):
method _prune_heads (line 660) | def _prune_heads(self, heads_to_prune):
method forward (line 670) | def forward(
class BertForPreTraining (line 778) | class BertForPreTraining(BertPreTrainedModel):
method __init__ (line 779) | def __init__(self, config):
method get_output_embeddings (line 787) | def get_output_embeddings(self):
method forward (line 791) | def forward(
class BertLMHeadModel (line 894) | class BertLMHeadModel(BertPreTrainedModel):
method __init__ (line 895) | def __init__(self, config):
method get_output_embeddings (line 904) | def get_output_embeddings(self):
method forward (line 908) | def forward(
method prepare_inputs_for_generation (line 994) | def prepare_inputs_for_generation(self, input_ids, attention_mask=None...
class BertForMaskedLM (line 1005) | class BertForMaskedLM(BertPreTrainedModel):
method __init__ (line 1006) | def __init__(self, config):
method get_output_embeddings (line 1017) | def get_output_embeddings(self):
method forward (line 1022) | def forward(
method prepare_inputs_for_generation (line 1098) | def prepare_inputs_for_generation(self, input_ids, attention_mask=None...
class BertForNextSentencePrediction (line 1116) | class BertForNextSentencePrediction(BertPreTrainedModel):
method __init__ (line 1117) | def __init__(self, config):
method forward (line 1126) | def forward(
class BertForSequenceClassification (line 1208) | class BertForSequenceClassification(BertPreTrainedModel):
method __init__ (line 1209) | def __init__(self, config):
method forward (line 1221) | def forward(
class BertForMultipleChoice (line 1295) | class BertForMultipleChoice(BertPreTrainedModel):
method __init__ (line 1296) | def __init__(self, config):
method forward (line 1307) | def forward(
class BertForTokenClassification (line 1389) | class BertForTokenClassification(BertPreTrainedModel):
method __init__ (line 1390) | def __init__(self, config):
method forward (line 1402) | def forward(
class BertForQuestionAnswering (line 1477) | class BertForQuestionAnswering(BertPreTrainedModel):
method __init__ (line 1478) | def __init__(self, config):
method forward (line 1489) | def forward(
FILE: bert/modeling_utils.py
class Identity (line 48) | class Identity(nn.Module):
method __init__ (line 52) | def __init__(self, *args, **kwargs):
method forward (line 55) | def forward(self, input):
function find_pruneable_heads_and_indices (line 59) | def find_pruneable_heads_and_indices(
class ModuleUtilsMixin (line 73) | class ModuleUtilsMixin:
method num_parameters (line 78) | def num_parameters(self, only_trainable: bool = False) -> int:
method _hook_rss_memory_pre_forward (line 86) | def _hook_rss_memory_pre_forward(module, *args, **kwargs):
method _hook_rss_memory_post_forward (line 98) | def _hook_rss_memory_post_forward(module, *args, **kwargs):
method add_memory_hooks (line 111) | def add_memory_hooks(self):
method reset_memory_hooks_state (line 120) | def reset_memory_hooks_state(self):
method device (line 127) | def device(self) -> device:
method dtype (line 145) | def dtype(self) -> dtype:
method invert_attention_mask (line 162) | def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Ten...
method get_extended_attention_mask (line 188) | def get_extended_attention_mask(self, attention_mask: Tensor, input_sh...
method get_head_mask (line 232) | def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_...
method _convert_head_mask_to_5d (line 253) | def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
class PreTrainedModel (line 265) | class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
method dummy_inputs (line 285) | def dummy_inputs(self):
method __init__ (line 293) | def __init__(self, config, *inputs, **kwargs):
method base_model (line 307) | def base_model(self):
method get_input_embeddings (line 310) | def get_input_embeddings(self):
method set_input_embeddings (line 324) | def set_input_embeddings(self, value: nn.Module):
method get_output_embeddings (line 338) | def get_output_embeddings(self):
method tie_weights (line 348) | def tie_weights(self):
method _tie_or_clone_weights (line 358) | def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
method resize_token_embeddings (line 376) | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
method _resize_token_embeddings (line 403) | def _resize_token_embeddings(self, new_num_tokens):
method _get_resized_embeddings (line 409) | def _get_resized_embeddings(
method init_weights (line 447) | def init_weights(self):
method prune_heads (line 459) | def prune_heads(self, heads_to_prune: Dict):
method save_pretrained (line 474) | def save_pretrained(self, save_directory):
method from_pretrained (line 510) | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, *...
class Conv1D (line 815) | class Conv1D(nn.Module):
method __init__ (line 816) | def __init__(self, nf, nx):
method forward (line 827) | def forward(self, x):
class PoolerStartLogits (line 834) | class PoolerStartLogits(nn.Module):
method __init__ (line 837) | def __init__(self, config):
method forward (line 841) | def forward(self, hidden_states, p_mask=None):
class PoolerEndLogits (line 858) | class PoolerEndLogits(nn.Module):
method __init__ (line 862) | def __init__(self, config):
method forward (line 869) | def forward(self, hidden_states, start_states=None, start_positions=No...
class PoolerAnswerClass (line 905) | class PoolerAnswerClass(nn.Module):
method __init__ (line 908) | def __init__(self, config):
method forward (line 914) | def forward(self, hidden_states, start_states=None, start_positions=No...
class SQuADHead (line 952) | class SQuADHead(nn.Module):
method __init__ (line 993) | def __init__(self, config):
method forward (line 1002) | def forward(
class SequenceSummary (line 1069) | class SequenceSummary(nn.Module):
method __init__ (line 1085) | def __init__(self, config: PretrainedConfig):
method forward (line 1114) | def forward(self, hidden_states, cls_index=None):
function prune_linear_layer (line 1146) | def prune_linear_layer(layer, index, dim=0):
function prune_conv1d_layer (line 1171) | def prune_conv1d_layer(layer, index, dim=1):
function prune_layer (line 1195) | def prune_layer(layer, index, dim=None):
function apply_chunking_to_forward (line 1208) | def apply_chunking_to_forward(
FILE: bert/tokenization_bert.py
function load_vocab (line 97) | def load_vocab(vocab_file):
function whitespace_tokenize (line 108) | def whitespace_tokenize(text):
class BertTokenizer (line 117) | class BertTokenizer(PreTrainedTokenizer):
method __init__ (line 161) | def __init__(
method vocab_size (line 199) | def vocab_size(self):
method get_vocab (line 202) | def get_vocab(self):
method _tokenize (line 205) | def _tokenize(self, text):
method _convert_token_to_id (line 219) | def _convert_token_to_id(self, token):
method _convert_id_to_token (line 223) | def _convert_id_to_token(self, index):
method convert_tokens_to_string (line 227) | def convert_tokens_to_string(self, tokens):
method build_inputs_with_special_tokens (line 232) | def build_inputs_with_special_tokens(
method get_special_tokens_mask (line 258) | def get_special_tokens_mask(
method create_token_type_ids_from_sequences (line 289) | def create_token_type_ids_from_sequences(
method save_vocabulary (line 319) | def save_vocabulary(self, vocab_path):
class BasicTokenizer (line 348) | class BasicTokenizer(object):
method __init__ (line 351) | def __init__(self, do_lower_case=True, never_split=None, tokenize_chin...
method tokenize (line 371) | def tokenize(self, text, never_split=None):
method _run_strip_accents (line 403) | def _run_strip_accents(self, text):
method _run_split_on_punc (line 414) | def _run_split_on_punc(self, text, never_split=None):
method _tokenize_chinese_chars (line 436) | def _tokenize_chinese_chars(self, text):
method _is_chinese_char (line 449) | def _is_chinese_char(self, cp):
method _clean_text (line 473) | def _clean_text(self, text):
class WordpieceTokenizer (line 487) | class WordpieceTokenizer(object):
method __init__ (line 490) | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
method tokenize (line 495) | def tokenize(self, text):
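`WordpieceTokenizer.tokenize` is listed above with a `max_input_chars_per_word=100` default, but its body is not shown. BERT-style WordPiece splits each whitespace-delimited word greedily. At each position it takes the longest vocabulary entry that matches and marks every non-initial piece with a `##` prefix. A word that has no valid split becomes the unknown token. The following is a minimal sketch of that idea, assuming a plain Python set as the vocabulary. `wordpiece_tokenize` and the toy vocabulary are hypothetical and are not taken from the repository.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
    """Greedy longest-match-first WordPiece split of a single whitespace-free word."""
    if len(word) > max_input_chars_per_word:
        return [unk_token]
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        cur = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while start < end:
            sub = word[start:end]
            if start > 0:
                sub = "##" + sub  # continuation pieces carry the "##" prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:
            # No piece matched at this position, so the whole word is unknown.
            return [unk_token]
        pieces.append(cur)
        start = end
    return pieces


vocab = {"un", "##aff", "##able", "play", "##ing"}
print(wordpiece_tokenize("unaffable", vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("playing", vocab))    # ['play', '##ing']
print(wordpiece_tokenize("xyz", vocab))        # ['[UNK]']
```

In BERT's tokenizer a real vocabulary is read from a file (compare `load_vocab` above). Sentences are also split by a basic tokenizer first, and each resulting word is then passed through this step.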
FILE: bert/tokenization_utils.py
function _is_whitespace (line 47) | def _is_whitespace(char):
function _is_control (line 59) | def _is_control(char):
function _is_punctuation (line 71) | def _is_punctuation(char):
function _is_end_of_word (line 86) | def _is_end_of_word(text):
function _is_start_of_word (line 92) | def _is_start_of_word(text):
class PreTrainedTokenizer (line 98) | class PreTrainedTokenizer(PreTrainedTokenizerBase):
method __init__ (line 156) | def __init__(self, **kwargs):
method is_fast (line 166) | def is_fast(self) -> bool:
method vocab_size (line 170) | def vocab_size(self) -> int:
method get_vocab (line 174) | def get_vocab(self):
method get_added_vocab (line 178) | def get_added_vocab(self) -> Dict[str, int]:
method __len__ (line 181) | def __len__(self):
method _add_tokens (line 185) | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], ...
method num_special_tokens_to_add (line 237) | def num_special_tokens_to_add(self, pair=False):
method tokenize (line 256) | def tokenize(self, text: TextInput, **kwargs):
method _tokenize (line 367) | def _tokenize(self, text, **kwargs):
method convert_tokens_to_ids (line 376) | def convert_tokens_to_ids(self, tokens):
method _convert_token_to_id_with_added_voc (line 391) | def _convert_token_to_id_with_added_voc(self, token):
method _convert_token_to_id (line 399) | def _convert_token_to_id(self, token):
method _encode_plus (line 402) | def _encode_plus(
method _batch_encode_plus (line 476) | def _batch_encode_plus(
method _batch_prepare_for_model (line 560) | def _batch_prepare_for_model(
method prepare_for_tokenization (line 623) | def prepare_for_tokenization(self, text: str, is_pretokenized=False, *...
method get_special_tokens_mask (line 631) | def get_special_tokens_mask(
method convert_ids_to_tokens (line 650) | def convert_ids_to_tokens(
method _convert_id_to_token (line 675) | def _convert_id_to_token(self, index: int) -> str:
method convert_tokens_to_string (line 678) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
method decode (line 685) | def decode(
method save_vocabulary (line 715) | def save_vocabulary(self, save_directory) -> Tuple[str]:
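The character-class helpers `_is_whitespace`, `_is_control` and `_is_punctuation` are listed here without bodies. The following is a minimal sketch, assuming they follow the classification used by the original BERT tokenizer. That scheme uses Unicode categories from `unicodedata`, treats tab, newline and carriage return as whitespace rather than control characters, and counts every non-alphanumeric ASCII symbol as punctuation. The function names without the leading underscore are illustrative.

```python
import unicodedata


def is_whitespace(char):
    # Tab, newline and carriage return count as whitespace even though
    # Unicode classifies them as control characters.
    if char in (" ", "\t", "\n", "\r"):
        return True
    return unicodedata.category(char) == "Zs"


def is_control(char):
    # These three are handled as whitespace, not as control characters.
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char).startswith("C")


def is_punctuation(char):
    cp = ord(char)
    # Treat every non-alphanumeric ASCII symbol (e.g. "$", "^", "`") as punctuation,
    # even where Unicode files it under a symbol category instead.
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")


print([is_punctuation(c) for c in "a,$"])  # [False, True, True]
```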
FILE: bert/tokenization_utils_base.py
class ExplicitEnum (line 74) | class ExplicitEnum(Enum):
method _missing_ (line 79) | def _missing_(cls, value):
class TruncationStrategy (line 86) | class TruncationStrategy(ExplicitEnum):
class PaddingStrategy (line 93) | class PaddingStrategy(ExplicitEnum):
class TensorType (line 99) | class TensorType(ExplicitEnum):
class CharSpan (line 105) | class CharSpan(NamedTuple):
class TokenSpan (line 117) | class TokenSpan(NamedTuple):
class BatchEncoding (line 129) | class BatchEncoding(UserDict):
method __init__ (line 145) | def __init__(
method is_fast (line 162) | def is_fast(self):
method __getitem__ (line 169) | def __getitem__(self, item: Union[int, str]) -> EncodingFast:
method __getattr__ (line 183) | def __getattr__(self, item: str):
method __getstate__ (line 189) | def __getstate__(self):
method __setstate__ (line 192) | def __setstate__(self, state):
method keys (line 199) | def keys(self):
method values (line 202) | def values(self):
method items (line 205) | def items(self):
method encodings (line 213) | def encodings(self) -> Optional[List[EncodingFast]]:
method tokens (line 221) | def tokens(self, batch_index: int = 0) -> List[str]:
method words (line 226) | def words(self, batch_index: int = 0) -> List[Optional[int]]:
method token_to_word (line 231) | def token_to_word(self, batch_or_token_index: int, token_index: Option...
method word_to_tokens (line 272) | def word_to_tokens(self, batch_or_word_index: int, word_index: Optiona...
method token_to_chars (line 321) | def token_to_chars(self, batch_or_token_index: int, token_index: Optio...
method char_to_token (line 362) | def char_to_token(self, batch_or_char_index: int, char_index: Optional...
method word_to_chars (line 398) | def word_to_chars(self, batch_or_word_index: int, word_index: Optional...
method char_to_word (line 439) | def char_to_word(self, batch_or_char_index: int, char_index: Optional[...
method convert_to_tensors (line 476) | def convert_to_tensors(self, tensor_type: Union[None, str, TensorType]...
method to (line 522) | def to(self, device: str):
class SpecialTokensMixin (line 565) | class SpecialTokensMixin:
method __init__ (line 583) | def __init__(self, verbose=True, **kwargs):
method sanitize_special_tokens (line 610) | def sanitize_special_tokens(self) -> int:
method add_special_tokens (line 619) | def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str,...
method add_tokens (line 680) | def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], Lis...
method bos_token (line 718) | def bos_token(self):
method eos_token (line 726) | def eos_token(self):
method unk_token (line 734) | def unk_token(self):
method sep_token (line 742) | def sep_token(self):
method pad_token (line 750) | def pad_token(self):
method cls_token (line 758) | def cls_token(self):
method mask_token (line 766) | def mask_token(self):
method additional_special_tokens (line 774) | def additional_special_tokens(self):
method bos_token (line 782) | def bos_token(self, value):
method eos_token (line 786) | def eos_token(self, value):
method unk_token (line 790) | def unk_token(self, value):
method sep_token (line 794) | def sep_token(self, value):
method pad_token (line 798) | def pad_token(self, value):
method cls_token (line 802) | def cls_token(self, value):
method mask_token (line 806) | def mask_token(self, value):
method additional_special_tokens (line 810) | def additional_special_tokens(self, value):
method bos_token_id (line 814) | def bos_token_id(self):
method eos_token_id (line 821) | def eos_token_id(self):
method unk_token_id (line 828) | def unk_token_id(self):
method sep_token_id (line 835) | def sep_token_id(self):
method pad_token_id (line 842) | def pad_token_id(self):
method pad_token_type_id (line 849) | def pad_token_type_id(self):
method cls_token_id (line 854) | def cls_token_id(self):
method mask_token_id (line 861) | def mask_token_id(self):
method additional_special_tokens_ids (line 868) | def additional_special_tokens_ids(self):
method special_tokens_map (line 873) | def special_tokens_map(self):
method special_tokens_map_extended (line 887) | def special_tokens_map_extended(self):
method all_special_tokens (line 902) | def all_special_tokens(self):
method all_special_tokens_extended (line 912) | def all_special_tokens_extended(self):
method all_special_ids (line 926) | def all_special_ids(self):
class PreTrainedTokenizerBase (line 1015) | class PreTrainedTokenizerBase(SpecialTokensMixin):
method __init__ (line 1029) | def __init__(self, **kwargs):
method max_len (line 1049) | def max_len(self) -> int:
method max_len_single_sentence (line 1056) | def max_len_single_sentence(self) -> int:
method max_len_sentences_pair (line 1060) | def max_len_sentences_pair(self) -> int:
method max_len_single_sentence (line 1064) | def max_len_single_sentence(self, value) -> int:
method max_len_sentences_pair (line 1076) | def max_len_sentences_pair(self, value) -> int:
method from_pretrained (line 1088) | def from_pretrained(cls, *inputs, **kwargs):
method _from_pretrained (line 1143) | def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs,...
method save_pretrained (line 1334) | def save_pretrained(self, save_directory) -> Tuple[str]:
method encode (line 1389) | def encode(
method num_special_tokens_to_add (line 1430) | def num_special_tokens_to_add(self, pair: bool = False) -> int:
method _get_padding_truncation_strategies (line 1433) | def _get_padding_truncation_strategies(
method __call__ (line 1551) | def __call__(
method encode_plus (line 1673) | def encode_plus(
method _encode_plus (line 1740) | def _encode_plus(
method batch_encode_plus (line 1764) | def batch_encode_plus(
method _batch_encode_plus (line 1835) | def _batch_encode_plus(
method pad (line 1864) | def pad(
method create_token_type_ids_from_sequences (line 1964) | def create_token_type_ids_from_sequences(self, token_ids_0: List, toke...
method build_inputs_with_special_tokens (line 1969) | def build_inputs_with_special_tokens(self, token_ids_0: List, token_id...
method prepare_for_model (line 1979) | def prepare_for_model(
method truncate_sequences (line 2103) | def truncate_sequences(
method _pad (line 2183) | def _pad(
method batch_decode (line 2253) | def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[s...
method decode (line 2256) | def decode(
method get_special_tokens_mask (line 2271) | def get_special_tokens_mask(
method clean_up_tokenization (line 2302) | def clean_up_tokenization(out_string: str) -> str:
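The base class declares `build_inputs_with_special_tokens`, `create_token_type_ids_from_sequences` and `get_special_tokens_mask`, and `BertTokenizer` overrides all three. For BERT the layout is `[CLS] A [SEP]` for one sequence and `[CLS] A [SEP] B [SEP]` for a pair. Segment ID 0 covers the first part and segment ID 1 covers the second. The following sketch works on plain ID lists. These standalone functions are hypothetical stand-ins for the methods. The defaults `cls_id=101` and `sep_id=102` assume the `bert-base-uncased` vocabulary; in practice, read these IDs from the tokenizer.

```python
def build_inputs_with_special_tokens(ids_a, ids_b=None, cls_id=101, sep_id=102):
    # Single sequence: [CLS] A [SEP]; pair: [CLS] A [SEP] B [SEP]
    if ids_b is None:
        return [cls_id] + ids_a + [sep_id]
    return [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]


def token_type_ids(ids_a, ids_b=None):
    # Segment 0 covers [CLS], A and the first [SEP]; segment 1 covers B and its [SEP].
    first = [0] * (len(ids_a) + 2)
    if ids_b is None:
        return first
    return first + [1] * (len(ids_b) + 1)


def special_tokens_mask(ids_a, ids_b=None):
    # 1 marks an added special token, 0 marks an ordinary sequence token.
    if ids_b is None:
        return [1] + [0] * len(ids_a) + [1]
    return [1] + [0] * len(ids_a) + [1] + [0] * len(ids_b) + [1]


print(build_inputs_with_special_tokens([7], [9, 10]))  # [101, 7, 102, 9, 10, 102]
print(token_type_ids([7], [9, 10]))                    # [0, 0, 0, 1, 1, 1]
```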
FILE: data/dataset_refer_bert.py
class ReferDataset (line 24) | class ReferDataset(data.Dataset):
method __init__ (line 26) | def __init__(self,
method get_classes (line 80) | def get_classes(self):
method __len__ (line 83) | def __len__(self):
method __getitem__ (line 86) | def __getitem__(self, index):
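`ReferDataset.__getitem__` has to turn each referring expression into a fixed-size input for the language encoder. The model signatures `forward(x, l_feats, l_mask)` and `forward(x, text, l_mask)` suggest that an attention mask travels alongside the token IDs. The following is a sketch of such a padding step, assuming right-padding with ID 0. The helper name and the default `max_tokens=20` are illustrative assumptions, not values confirmed by this index.

```python
def pad_to_fixed_length(token_ids, max_tokens=20, pad_id=0):
    """Truncate or right-pad token IDs and build the matching attention mask."""
    ids = token_ids[:max_tokens]                        # truncate long expressions
    n_pad = max_tokens - len(ids)
    padded = ids + [pad_id] * n_pad
    mask = [1] * len(ids) + [0] * n_pad                 # 1 = real token, 0 = padding
    return padded, mask


print(pad_to_fixed_length([101, 7, 102], max_tokens=5))
# ([101, 7, 102, 0, 0], [1, 1, 1, 0, 0])
```

If truncation happens after the special tokens are added, long expressions can lose their final `[SEP]`. Whether that matters depends on how the text encoder was trained.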
FILE: demo_inference.py
class args (line 50) | class args:
function overlay_davis (line 82) | def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1...
FILE: lib/_utils.py
class _LAVTSimpleDecode (line 9) | class _LAVTSimpleDecode(nn.Module):
method __init__ (line 10) | def __init__(self, backbone, classifier):
method forward (line 15) | def forward(self, x, l_feats, l_mask):
class LAVT (line 25) | class LAVT(_LAVTSimpleDecode):
class _LAVTOneSimpleDecode (line 32) | class _LAVTOneSimpleDecode(nn.Module):
method __init__ (line 33) | def __init__(self, backbone, classifier, args):
method forward (line 40) | def forward(self, x, text, l_mask):
class LAVTOne (line 55) | class LAVTOne(_LAVTOneSimpleDecode):
FILE: lib/backbone.py
class Mlp (line 11) | class Mlp(nn.Module):
method __init__ (line 14) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 23) | def forward(self, x):
function window_partition (line 32) | def window_partition(x, window_size):
function window_reverse (line 47) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 64) | class WindowAttention(nn.Module):
method __init__ (line 78) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 112) | def forward(self, x, mask=None):
class SwinTransformerBlock (line 145) | class SwinTransformerBlock(nn.Module):
method __init__ (line 163) | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
method forward (line 187) | def forward(self, x, mask_matrix):
class PatchMerging (line 247) | class PatchMerging(nn.Module):
method __init__ (line 254) | def __init__(self, dim, norm_layer=nn.LayerNorm):
method forward (line 260) | def forward(self, x, H, W):
class PatchEmbed (line 290) | class PatchEmbed(nn.Module):
method __init__ (line 300) | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=...
method forward (line 314) | def forward(self, x):
class MultiModalSwinTransformer (line 333) | class MultiModalSwinTransformer(nn.Module):
method __init__ (line 334) | def __init__(self,
method _freeze_stages (line 419) | def _freeze_stages(self):
method init_weights (line 436) | def init_weights(self, pretrained=None):
method forward (line 462) | def forward(self, x, l, l_mask):
method train (line 489) | def train(self, mode=True):
class MMBasicLayer (line 495) | class MMBasicLayer(nn.Module):
method __init__ (line 496) | def __init__(self,
method forward (line 557) | def forward(self, x, H, W, l, l_mask):
class PWAM (line 606) | class PWAM(nn.Module):
method __init__ (line 607) | def __init__(self, dim, v_in_channels, l_in_channels, key_channels, va...
method forward (line 627) | def forward(self, x, l, l_mask):
class SpatialImageLanguageAttention (line 643) | class SpatialImageLanguageAttention(nn.Module):
method __init__ (line 644) | def __init__(self, v_in_channels, l_in_channels, key_channels, value_c...
method forward (line 681) | def forward(self, x, l, l_mask):
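`window_partition(x, window_size)` and `window_reverse(windows, window_size, H, W)` follow the Swin Transformer convention. A (B, H, W, C) feature map is cut into non-overlapping `window_size` x `window_size` windows, attention runs inside each window, and `window_reverse` stitches the windows back together. The sketch below reproduces the window ordering (row-major over the window grid) on a single-channel, single-image H x W grid. It uses plain lists instead of the tensor reshapes the repository presumably uses. It is illustrative only and requires H and W to be multiples of the window size.

```python
def window_partition(grid, ws):
    """Split an H x W grid (list of rows) into non-overlapping ws x ws windows.

    Windows are returned in row-major order over the (H // ws, W // ws) window grid.
    """
    H, W = len(grid), len(grid[0])
    assert H % ws == 0 and W % ws == 0, "H and W must be multiples of the window size"
    windows = []
    for wy in range(H // ws):
        for wx in range(W // ws):
            windows.append([row[wx * ws:(wx + 1) * ws]
                            for row in grid[wy * ws:(wy + 1) * ws]])
    return windows


def window_reverse(windows, ws, H, W):
    """Inverse of window_partition: stitch windows back into an H x W grid."""
    grid = [[None] * W for _ in range(H)]
    per_row = W // ws
    for idx, win in enumerate(windows):
        wy, wx = divmod(idx, per_row)  # position of this window in the window grid
        for dy in range(ws):
            for dx in range(ws):
                grid[wy * ws + dy][wx * ws + dx] = win[dy][dx]
    return grid


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
wins = window_partition(grid, 2)
print(wins[1])  # [[2, 3], [6, 7]]
```

In Swin-style blocks, the feature map is typically padded up to a multiple of the window size before partitioning. For the shifted-window variant, the map is also rolled by `shift_size` before this step.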
FILE: lib/mask_predictor.py
class SimpleDecoding (line 7) | class SimpleDecoding(nn.Module):
method __init__ (line 8) | def __init__(self, c4_dims, factor=2):
method forward (line 40) | def forward(self, x_c4, x_c3, x_c2, x_c1):
FILE: lib/mmcv_custom/checkpoint.py
function _get_mmcv_home (line 30) | def _get_mmcv_home():
function load_state_dict (line 41) | def load_state_dict(module, state_dict, strict=False, logger=None):
function load_url_dist (line 110) | def load_url_dist(url, model_dir=None):
function load_pavimodel_dist (line 124) | def load_pavimodel_dist(model_path, map_location=None):
function load_fileclient_dist (line 152) | def load_fileclient_dist(filename, backend, map_location):
function get_torchvision_models (line 173) | def get_torchvision_models():
function get_external_models (line 185) | def get_external_models():
function get_mmcls_models (line 199) | def get_mmcls_models():
function get_deprecated_model_names (line 206) | def get_deprecated_model_names():
function _process_mmcls_checkpoint (line 215) | def _process_mmcls_checkpoint(checkpoint):
function _load_checkpoint (line 226) | def _load_checkpoint(filename, map_location=None):
function load_checkpoint (line 287) | def load_checkpoint(model,
function weights_to_cpu (line 363) | def weights_to_cpu(state_dict):
function _save_to_state_dict (line 378) | def _save_to_state_dict(module, destination, prefix, keep_vars):
function get_state_dict (line 398) | def get_state_dict(module, destination=None, prefix='', keep_vars=False):
function save_checkpoint (line 442) | def save_checkpoint(model, filename, optimizer=None, meta=None):
FILE: lib/segmentation.py
function _segm_lavt (line 11) | def _segm_lavt(pretrained, args):
function _load_model_lavt (line 68) | def _load_model_lavt(pretrained, args):
function lavt (line 73) | def lavt(pretrained='', args=None):
function _segm_lavt_one (line 80) | def _segm_lavt_one(pretrained, args):
function _load_model_lavt_one (line 137) | def _load_model_lavt_one(pretrained, args):
function lavt_one (line 142) | def lavt_one(pretrained='', args=None):
FILE: refer/evaluation/bleu/bleu.py
class Bleu (line 14) | class Bleu:
method __init__ (line 15) | def __init__(self, n=4):
method compute_score (line 21) | def compute_score(self, gts, res):
method method (line 46) | def method(self):
FILE: refer/evaluation/bleu/bleu_scorer.py
function precook (line 23) | def precook(s, n=4, out=False):
function cook_refs (line 35) | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "av...
function cook_test (line 60) | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4):
class BleuScorer (line 85) | class BleuScorer(object):
method copy (line 92) | def copy(self):
method __init__ (line 100) | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
method cook_append (line 109) | def cook_append(self, test, refs):
method ratio (line 122) | def ratio(self, option=None):
method score_ratio (line 126) | def score_ratio(self, option=None):
method score_ratio_str (line 130) | def score_ratio_str(self, option=None):
method reflen (line 133) | def reflen(self, option=None):
method testlen (line 137) | def testlen(self, option=None):
method retest (line 141) | def retest(self, new_test):
method rescore (line 152) | def rescore(self, new_test):
method size (line 157) | def size(self):
method __iadd__ (line 161) | def __iadd__(self, other):
method compatible (line 175) | def compatible(self, other):
method single_reflen (line 178) | def single_reflen(self, option="average"):
method _single_reflen (line 181) | def _single_reflen(self, reflens, option=None, testlen=None):
method recompute_score (line 194) | def recompute_score(self, option=None, verbose=0):
method compute_score (line 198) | def compute_score(self, option=None, verbose=0):
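`precook`, `cook_refs` and `cook_test` in `bleu_scorer.py` point to the usual BLEU bookkeeping. Each sentence is reduced to counts of its n-grams up to order `n`. Each hypothesis count is then clipped by the largest count of that n-gram in any single reference, which is what the `refmaxcounts` argument of `cook_test` suggests. Note that the tuple parameter in `cook_test`'s signature is Python 2-only syntax. The following stdlib sketch assumes that scheme. `precook` returns the sentence length as well, since the brevity penalty needs it. `clipped_matches` is a hypothetical helper.

```python
from collections import Counter


def precook(sentence, n=4):
    """Count every n-gram of order 1..n in a whitespace-tokenized sentence."""
    words = sentence.split()
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return len(words), counts


def clipped_matches(test, refs, n=4):
    """Per-order matched n-gram counts, clipped by the max count in any reference."""
    _, test_counts = precook(test, n)
    max_ref = Counter()
    for ref in refs:
        _, ref_counts = precook(ref, n)
        for ngram, c in ref_counts.items():
            max_ref[ngram] = max(max_ref[ngram], c)
    correct = [0] * n
    for ngram, c in test_counts.items():
        correct[len(ngram) - 1] += min(c, max_ref[ngram])
    return correct


print(clipped_matches("the the the", ["the cat"], n=1))  # [1]
```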
FILE: refer/evaluation/cider/cider.py
class Cider (line 13) | class Cider:
method __init__ (line 18) | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
method compute_score (line 24) | def compute_score(self, gts, res):
method method (line 53) | def method(self):
FILE: refer/evaluation/cider/cider_scorer.py
function precook (line 11) | def precook(s, n=4, out=False):
function cook_refs (line 28) | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
function cook_test (line 38) | def cook_test(test, n=4):
class CiderScorer (line 47) | class CiderScorer(object):
method copy (line 51) | def copy(self):
method __init__ (line 58) | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
method cook_append (line 68) | def cook_append(self, test, refs):
method size (line 78) | def size(self):
method __iadd__ (line 82) | def __iadd__(self, other):
method compute_doc_freq (line 93) | def compute_doc_freq(self):
method compute_cider (line 106) | def compute_cider(self):
method compute_score (line 183) | def compute_score(self, option=None, verbose=0):
FILE: refer/evaluation/meteor/meteor.py
class Meteor (line 15) | class Meteor:
method __init__ (line 17) | def __init__(self):
method compute_score (line 28) | def compute_score(self, gts, res):
method method (line 48) | def method(self):
method _stat (line 51) | def _stat(self, hypothesis_str, reference_list):
method _score (line 58) | def _score(self, hypothesis_str, reference_list):
method __exit__ (line 72) | def __exit__(self):
FILE: refer/evaluation/refEvaluation.py
class RefEvaluation (line 16) | class RefEvaluation:
method __init__ (line 17) | def __init__ (self, refer, Res):
method evaluate (line 28) | def evaluate(self):
method setEval (line 72) | def setEval(self, score, method):
method setRefToEvalRefs (line 75) | def setRefToEvalRefs(self, scores, refIds, method):
method setEvalRefs (line 82) | def setEvalRefs(self):
FILE: refer/evaluation/rouge/rouge.py
function my_lcs (line 13) | def my_lcs(string, sub):
class Rouge (line 36) | class Rouge():
method __init__ (line 41) | def __init__(self):
method calc_score (line 45) | def calc_score(self, candidate, refs):
method compute_score (line 77) | def compute_score(self, gts, res):
method method (line 104) | def method(self):
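`my_lcs(string, sub)` and `Rouge.calc_score(candidate, refs)` compute ROUGE-L. The longest common subsequence (LCS) between candidate and reference tokens gives precision `lcs / len(candidate)` and recall `lcs / len(reference)`. These are combined into an F-measure weighted by `beta`. The following stdlib sketch assumes that the best precision and best recall over all references are combined, and that `beta=1.2`, as in the coco-caption code this folder is adapted from. The function names are hypothetical.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (O(len(a)*len(b)) DP)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l(candidate, refs, beta=1.2):
    """ROUGE-L F-score of one non-empty candidate against non-empty references."""
    cand = candidate.split()
    precs, recs = [], []
    for ref in refs:
        r = ref.split()
        lcs = lcs_length(cand, r)
        precs.append(lcs / len(cand))
        recs.append(lcs / len(r))
    p, r = max(precs), max(recs)
    if p == 0 or r == 0:
        return 0.0
    return ((1 + beta ** 2) * p * r) / (r + beta ** 2 * p)


print(lcs_length(list("ABCBDAB"), list("BDCABA")))  # 4
```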
FILE: refer/evaluation/tokenizer/ptbtokenizer.py
class PTBTokenizer (line 24) | class PTBTokenizer:
method tokenize (line 27) | def tokenize(self, captions_for_image):
FILE: refer/external/maskApi.c
function umin (line 11) | uint umin( uint a, uint b ) { return (a<b) ? a : b; }
function umax (line 12) | uint umax( uint a, uint b ) { return (a>b) ? a : b; }
function rleInit (line 14) | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ) {
function rleFree (line 19) | void rleFree( RLE *R ) {
function rlesInit (line 23) | void rlesInit( RLE **R, siz n ) {
function rlesFree (line 28) | void rlesFree( RLE **R, siz n ) {
function rleEncode (line 32) | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) {
function rleDecode (line 43) | void rleDecode( const RLE *R, byte *M, siz n ) {
function rleMerge (line 49) | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ) {
function rleArea (line 72) | void rleArea( const RLE *R, siz n, uint *a ) {
function rleIou (line 77) | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) {
function rleNms (line 98) | void rleNms( RLE *dt, siz n, uint *keep, double thr ) {
function bbIou (line 109) | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) {
function bbNms (line 122) | void bbNms( BB dt, siz n, uint *keep, double thr ) {
function rleToBbox (line 133) | void rleToBbox( const RLE *R, BB bb, siz n ) {
function rleFrBbox (line 148) | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) {
function uintCompare (line 157) | int uintCompare(const void *a, const void *b) {
function rleFrPoly (line 161) | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) {
function rleFrString (line 217) | void rleFrString( RLE *R, char *s, siz h, siz w ) {
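`rleEncode`, `rleDecode` and `rleArea` operate on the COCO run-length encoding. The `RLE` struct holds `h`, `w`, the number of runs `m`, and the run lengths `cnts`. The mask is scanned in column-major order, and `cnts` stores alternating runs of 0s and 1s. The first run is always of 0s and may be empty. The area is therefore the sum of the odd-indexed counts. The following pure-Python sketch shows that format on a list-of-rows mask. It keeps the counts as an uncompressed list, whereas `rleFrString` parses a compact string form.

```python
def rle_encode(mask):
    """Run-length encode a binary mask (list of rows) in column-major order.

    Counts alternate between runs of 0s and 1s and always start with a run of 0s,
    which may be empty.
    """
    h, w = len(mask), len(mask[0])
    flat = [mask[y][x] for x in range(w) for y in range(h)]  # column-major scan
    counts, prev, run = [], 0, 0
    for v in flat:
        if v != prev:
            counts.append(run)
            run, prev = 0, v
        run += 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}


def rle_area(rle):
    # Foreground pixels are the odd-indexed runs.
    return sum(rle["counts"][1::2])


def rle_decode(rle):
    h, w = rle["size"]
    flat, val = [], 0
    for c in rle["counts"]:
        flat.extend([val] * c)
        val = 1 - val
    # Undo the column-major scan: pixel (y, x) sits at index x * h + y.
    return [[flat[x * h + y] for x in range(w)] for y in range(h)]


r = rle_encode([[0, 1], [1, 1]])
print(r["counts"], rle_area(r))  # [1, 3] 3
```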
FILE: refer/external/maskApi.h
type uint (line 9) | typedef unsigned int uint;
type siz (line 10) | typedef unsigned long siz;
type byte (line 11) | typedef unsigned char byte;
type RLE (line 13) | typedef struct { siz h, w, m; uint *cnts; } RLE;
FILE: refer/refer.py
class REFER (line 40) | class REFER:
method __init__ (line 42) | def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
method createIndex (line 78) | def createIndex(self):
method getRefIds (line 141) | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
method getAnnIds (line 172) | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
method getImgIds (line 193) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 202) | def getCatIds(self):
method loadRefs (line 205) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 211) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 217) | def loadImgs(self, image_ids=[]):
method loadCats (line 223) | def loadCats(self, cat_ids=[]):
method getRefBox (line 229) | def getRefBox(self, ref_id):
method showRef (line 234) | def showRef(self, ref, seg_box='seg'):
method getMask (line 277) | def getMask(self, ref):
method showMask (line 295) | def showMask(self, ref):
FILE: test.py
function get_dataset (line 21) | def get_dataset(image_set, transform, args):
function evaluate (line 33) | def evaluate(model, data_loader, bert_model, device):
function get_transform (line 92) | def get_transform(args):
function computeIoU (line 101) | def computeIoU(pred_seg, gd_seg):
function main (line 108) | def main(args):
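The index shows `computeIoU(pred_seg, gd_seg)` and `evaluate(model, data_loader, bert_model, device)` in `test.py` but not their bodies. Referring segmentation is commonly reported with three metrics: mean IoU (the average of per-expression IoU), overall IoU (total intersection over total union across the dataset), and precision at several IoU thresholds. The following sketch works on flat 0/1 lists and assumes that convention. The function names, the threshold set and the choice to score an empty union as 0 are assumptions, not a reading of the repository's code.

```python
def compute_iou(pred, gt):
    """Return (intersection, union) pixel counts for two flat binary masks."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter, union


def summarize(pairs, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """Mean IoU, overall (cumulative) IoU and precision@threshold over a dataset."""
    ious, cum_i, cum_u = [], 0, 0
    for pred, gt in pairs:
        i, u = compute_iou(pred, gt)
        ious.append(i / u if u > 0 else 0.0)  # assumption: an empty union scores 0
        cum_i += i
        cum_u += u
    mean_iou = sum(ious) / len(ious)
    overall_iou = cum_i / cum_u if cum_u > 0 else 0.0
    prec = {t: sum(iou >= t for iou in ious) / len(ious) for t in thresholds}
    return mean_iou, overall_iou, prec


pairs = [([1, 1, 0, 0], [1, 0, 0, 0]), ([1, 1, 1, 1], [1, 1, 1, 1])]
print(summarize(pairs)[:2])  # mean IoU 0.75, overall IoU 5/6
```

Overall IoU weights large objects more heavily than mean IoU does, which is why the two are usually reported together.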
FILE: train.py
function get_dataset (line 26) | def get_dataset(image_set, transform, args):
function IoU (line 39) | def IoU(pred, gt):
function get_transform (line 53) | def get_transform(args):
function criterion (line 62) | def criterion(input, target):
function evaluate (line 67) | def evaluate(model, data_loader, bert_model):
function train_one_epoch (line 126) | def train_one_epoch(model, criterion, optimizer, data_loader, lr_schedul...
function main (line 174) | def main(args):
FILE: transforms.py
function pad_if_smaller (line 10) | def pad_if_smaller(img, size, fill=0):
class Compose (line 20) | class Compose(object):
method __init__ (line 21) | def __init__(self, transforms):
method __call__ (line 24) | def __call__(self, image, target):
class Resize (line 30) | class Resize(object):
method __init__ (line 31) | def __init__(self, h, w):
method __call__ (line 35) | def __call__(self, image, target):
class RandomResize (line 43) | class RandomResize(object):
method __init__ (line 44) | def __init__(self, min_size, max_size=None):
method __call__ (line 50) | def __call__(self, image, target):
class RandomHorizontalFlip (line 59) | class RandomHorizontalFlip(object):
method __init__ (line 60) | def __init__(self, flip_prob):
method __call__ (line 63) | def __call__(self, image, target):
class RandomCrop (line 70) | class RandomCrop(object):
method __init__ (line 71) | def __init__(self, size):
method __call__ (line 74) | def __call__(self, image, target):
class CenterCrop (line 83) | class CenterCrop(object):
method __init__ (line 84) | def __init__(self, size):
method __call__ (line 87) | def __call__(self, image, target):
class ToTensor (line 93) | class ToTensor(object):
method __call__ (line 94) | def __call__(self, image, target):
class RandomAffine (line 100) | class RandomAffine(object):
method __init__ (line 101) | def __init__(self, angle, translate, scale, shear, resample=0, fillcol...
method __call__ (line 109) | def __call__(self, image, target):
class Normalize (line 116) | class Normalize(object):
method __init__ (line 117) | def __init__(self, mean, std):
method __call__ (line 121) | def __call__(self, image, target):
FILE: utils.py
class SmoothedValue (line 16) | class SmoothedValue(object):
method __init__ (line 21) | def __init__(self, window_size=20, fmt=None):
method update (line 29) | def update(self, value, n=1):
method synchronize_between_processes (line 34) | def synchronize_between_processes(self):
method median (line 48) | def median(self):
method avg (line 53) | def avg(self):
method global_avg (line 58) | def global_avg(self):
method max (line 62) | def max(self):
method value (line 66) | def value(self):
method __str__ (line 69) | def __str__(self):
class MetricLogger (line 78) | class MetricLogger(object):
method __init__ (line 79) | def __init__(self, delimiter="\t"):
method update (line 83) | def update(self, **kwargs):
method __getattr__ (line 90) | def __getattr__(self, attr):
method __str__ (line 98) | def __str__(self):
method synchronize_between_processes (line 106) | def synchronize_between_processes(self):
method add_meter (line 110) | def add_meter(self, name, meter):
method log_every (line 113) | def log_every(self, iterable, print_freq, header=None):
function mkdir (line 153) | def mkdir(path):
function setup_for_distributed (line 161) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 176) | def is_dist_avail_and_initialized():
function get_world_size (line 184) | def get_world_size():
function get_rank (line 190) | def get_rank():
function is_main_process (line 196) | def is_main_process():
function save_on_master (line 200) | def save_on_master(*args, **kwargs):
function init_distributed_mode (line 205) | def init_distributed_mode(args):
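`SmoothedValue(window_size=20, fmt=None)` in `utils.py` appears to keep a sliding window of recent values for the median and windowed average, plus a running total and count for the global average. `MetricLogger` presumably uses these values in its training logs. The following stdlib sketch of that bookkeeping assumes a `collections.deque` window. It omits the `fmt` formatting and the distributed `synchronize_between_processes` step.

```python
from collections import deque
from statistics import median


class SmoothedValue:
    """Track a series of values; expose windowed and global statistics."""

    def __init__(self, window_size=20):
        self.window = deque(maxlen=window_size)  # keeps only the most recent values
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.window.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return median(self.window)  # module-level statistics.median

    @property
    def avg(self):
        return sum(self.window) / len(self.window)

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.window)  # builtin max; class attributes are not in scope here

    @property
    def value(self):
        return self.window[-1]


s = SmoothedValue(window_size=3)
for v in [1, 2, 3, 10]:
    s.update(v)
print(s.median, s.avg, s.global_avg)  # 3 5.0 4.0
```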
Condensed preview: 59 files, each showing path, character count, and a content snippet (full structured content: 1,257K chars).
[
{
"path": "LICENSE",
"chars": 35149,
"preview": " GNU GENERAL PUBLIC LICENSE\n Version 3, 29 June 2007\n\n Copyright (C) 2007 Free "
},
{
"path": "README.md",
"chars": 15130,
"preview": "# LAVT: Language-Aware Vision Transformer for Referring Image Segmentation\nWelcome to the official repository for the me"
},
{
"path": "args.py",
"chars": 3685,
"preview": "import argparse\n\n\ndef get_parser():\n parser = argparse.ArgumentParser(description='LAVT training and testing')\n pa"
},
{
"path": "bert/activations.py",
"chars": 1537,
"preview": "import logging\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef sw"
},
{
"path": "bert/configuration_bert.py",
"chars": 8872,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018,"
},
{
"path": "bert/configuration_utils.py",
"chars": 19039,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018,"
},
{
"path": "bert/file_utils.py",
"chars": 28913,
"preview": "\"\"\"\nUtilities for working with the local dataset cache.\nThis file is adapted from the AllenNLP library at https://github"
},
{
"path": "bert/generation_utils.py",
"chars": 48035,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace In"
},
{
"path": "bert/modeling_bert.py",
"chars": 72476,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018,"
},
{
"path": "bert/modeling_utils.py",
"chars": 63052,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace In"
},
{
"path": "bert/tokenization_bert.py",
"chars": 24008,
"preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n#\n# Licensed under th"
},
{
"path": "bert/tokenization_utils.py",
"chars": 32754,
"preview": "# coding=utf-8\n# Copyright 2020 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
},
{
"path": "bert/tokenization_utils_base.py",
"chars": 109570,
"preview": "# coding=utf-8\n# Copyright 2020 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
},
{
"path": "data/dataset_refer_bert.py",
"chars": 4050,
"preview": "import os\nimport sys\nimport torch.utils.data as data\nimport torch\nfrom torchvision import transforms\nfrom torch.autograd"
},
{
"path": "demo_inference.py",
"chars": 4159,
"preview": "image_path = './demo/demo.jpg'\nsentence = 'the most handsome guy'\nweights = './checkpoints/refcoco.pth'\ndevice = 'cuda:0"
},
{
"path": "lib/_utils.py",
"chars": 1870,
"preview": "from collections import OrderedDict\nimport sys\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfr"
},
{
"path": "lib/backbone.py",
"chars": 28031,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nimport nu"
},
{
"path": "lib/mask_predictor.py",
"chars": 2872,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom collections import OrderedDict\n\n\nclass Simpl"
},
{
"path": "lib/mmcv_custom/__init__.py",
"chars": 96,
"preview": "# -*- coding: utf-8 -*-\n\nfrom .checkpoint import load_checkpoint\n\n__all__ = ['load_checkpoint']\n"
},
{
"path": "lib/mmcv_custom/checkpoint.py",
"chars": 19326,
"preview": "# Copyright (c) Open-MMLab. All rights reserved.\nimport io\nimport os\nimport os.path as osp\nimport pkgutil\nimport time\nim"
},
{
"path": "lib/segmentation.py",
"chars": 4826,
"preview": "import torch\nimport torch.nn as nn\nfrom .mask_predictor import SimpleDecoding\nfrom .backbone import MultiModalSwinTransf"
},
{
"path": "refer/LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "refer/Makefile",
"chars": 142,
"preview": "all:\n\t# install pycocotools/mask locally\n\t# copy from https://github.com/pdollar/coco.git\n\tpython setup.py build_ext --i"
},
{
"path": "refer/README.md",
"chars": 3075,
"preview": "## Note\nThis API is able to load all 4 referring expression datasets, i.e., RefClef, RefCOCO, RefCOCO+ and RefCOCOg. \nTh"
},
{
"path": "refer/data/README.md",
"chars": 1377,
"preview": "This directory should contain the following data:\n```\n$DATA_PATH\n├── images\n│ ├── mscoco\n│ └── saiaprtc12\n├── refcoc"
},
{
"path": "refer/evaluation/__init__.py",
"chars": 25,
"preview": "__author__ = 'licheng'\n\n\n"
},
{
"path": "refer/evaluation/bleu/LICENSE",
"chars": 1105,
"preview": "Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam\n\nPermission is hereby granted, free of "
},
{
"path": "refer/evaluation/bleu/__init__.py",
"chars": 21,
"preview": "__author__ = 'tylin'\n"
},
{
"path": "refer/evaluation/bleu/bleu.py",
"chars": 1246,
"preview": "#!/usr/bin/env python\n# \n# File Name : bleu.py\n#\n# Description : Wrapper for BLEU scorer.\n#\n# Creation Date : 06-01-2015"
},
{
"path": "refer/evaluation/bleu/bleu_scorer.py",
"chars": 8703,
"preview": "#!/usr/bin/env python\n\n# bleu_scorer.py\n# David Chiang <chiang@isi.edu>\n\n# Copyright (c) 2004-2006 University of Marylan"
},
{
"path": "refer/evaluation/cider/__init__.py",
"chars": 21,
"preview": "__author__ = 'tylin'\n"
},
{
"path": "refer/evaluation/cider/cider.py",
"chars": 1670,
"preview": "# Filename: cider.py\n#\n# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evalua"
},
{
"path": "refer/evaluation/cider/cider_scorer.py",
"chars": 7694,
"preview": "#!/usr/bin/env python\n# Tsung-Yi Lin <tl483@cornell.edu>\n# Ramakrishna Vedantam <vrama91@vt.edu>\n\nimport copy\nfrom colle"
},
{
"path": "refer/evaluation/meteor/__init__.py",
"chars": 21,
"preview": "__author__ = 'tylin'\n"
},
{
"path": "refer/evaluation/meteor/meteor.py",
"chars": 2779,
"preview": "#!/usr/bin/env python\n\n# Python wrapper for METEOR implementation, by Xinlei Chen\n# Acknowledge Michael Denkowski for th"
},
{
"path": "refer/evaluation/readme.txt",
"chars": 317,
"preview": "This folder contains modified coco-caption evaluation, which is downloaded from https://github.com/tylin/coco-caption.gi"
},
{
"path": "refer/evaluation/refEvaluation.py",
"chars": 4212,
"preview": "from tokenizer.ptbtokenizer import PTBTokenizer\nfrom bleu.bleu import Bleu\nfrom meteor.meteor import Meteor\nfrom rouge.r"
},
{
"path": "refer/evaluation/rouge/__init__.py",
"chars": 23,
"preview": "__author__ = 'vrama91'\n"
},
{
"path": "refer/evaluation/rouge/rouge.py",
"chars": 3643,
"preview": "#!/usr/bin/env python\n# \n# File Name : rouge.py\n#\n# Description : Computes ROUGE-L metric as described by Lin and Hovey "
},
{
"path": "refer/evaluation/tokenizer/__init__.py",
"chars": 21,
"preview": "__author__ = 'hfang'\n"
},
{
"path": "refer/evaluation/tokenizer/ptbtokenizer.py",
"chars": 2797,
"preview": "#!/usr/bin/env python\n# \n# File Name : ptbtokenizer.py\n#\n# Description : Do the PTB Tokenization and remove punctuations"
},
{
"path": "refer/external/README.md",
"chars": 90,
"preview": "The codes inside this folder are copied from pycocotools: https://github.com/pdollar/coco\r"
},
{
"path": "refer/external/__init__.py",
"chars": 21,
"preview": "__author__ = 'tylin'\n"
},
{
"path": "refer/external/_mask.pyx",
"chars": 10697,
"preview": "# distutils: language = c\n# distutils: sources = external/maskApi.c\n\n#**************************************************"
},
{
"path": "refer/external/mask.py",
"chars": 4055,
"preview": "__author__ = 'tsungyi'\n\nimport external._mask as _mask\n\n# Interface for manipulating masks stored in RLE format.\n#\n# RLE"
},
{
"path": "refer/external/maskApi.c",
"chars": 8249,
"preview": "/**************************************************************************\n* Microsoft COCO Toolbox. version 2.0\n*"
},
{
"path": "refer/external/maskApi.h",
"chars": 2176,
"preview": "/**************************************************************************\n* Microsoft COCO Toolbox. version 2.0\n*"
},
{
"path": "refer/pyEvalDemo.ipynb",
"chars": 145856,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {\n \"collapsed\": false\n },\n \"out"
},
{
"path": "refer/pyReferDemo.ipynb",
"chars": 267368,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {\n \"collapsed\": false\n },\n \"out"
},
{
"path": "refer/refer.py",
"chars": 12633,
"preview": "\"\"\"\nThis interface provides access to four datasets:\n1) refclef\n2) refcoco\n3) refcoco+\n4) refcocog\nsplit by unc and goog"
},
{
"path": "refer/setup.py",
"chars": 609,
"preview": "from distutils.core import setup\nfrom Cython.Build import cythonize\nfrom distutils.extension import Extension\nimport num"
},
{
"path": "refer/test/sample_expressions_testA.json",
"chars": 76053,
"preview": "{\"predictions\":[{\"sent\":\"man in black\",\"ref_id\":47},{\"sent\":\"person on right\",\"ref_id\":109},{\"sent\":\"woman in red\",\"ref_"
},
{
"path": "refer/test/sample_expressions_testB.json",
"chars": 72428,
"preview": "{\"predictions\":[{\"sent\":\"car on left\",\"ref_id\":25},{\"sent\":\"car on left\",\"ref_id\":26},{\"sent\":\"top sandwich\",\"ref_id\":27"
},
{
"path": "requirements.txt",
"chars": 167,
"preview": "requests\nfilelock\ntqdm\ntimm\nmmcv-full==1.3.12\nmmsegmentation==0.17.0\nftfy\nregex\nscipy\nscikit-image\npycocotools==2.0.2\nop"
},
{
"path": "test.py",
"chars": 4995,
"preview": "import datetime\nimport os\nimport time\n\nimport torch\nimport torch.utils.data\nfrom torch import nn\n\nfrom bert.modeling_ber"
},
{
"path": "train.py",
"chars": 12354,
"preview": "import datetime\nimport os\nimport time\n\nimport torch\nimport torch.utils.data\nfrom torch import nn\n\nfrom functools import "
},
{
"path": "transforms.py",
"chars": 3916,
"preview": "import numpy as np\nfrom PIL import Image\nimport random\n\nimport torch\nfrom torchvision import transforms as T\nfrom torchv"
},
{
"path": "utils.py",
"chars": 6164,
"preview": "from __future__ import print_function\nfrom collections import defaultdict, deque\nimport datetime\nimport math\nimport time"
}
]