Full Code of BlinkDL/RWKV-LM for AI

main c61a66cf725f cached
139 files
7.7 MB
2.0M tokens
1025 symbols
1 requests
Download .txt
Showing preview only (8,080K chars total). Download the full file or copy to clipboard to get everything.
Repository: BlinkDL/RWKV-LM
Branch: main
Commit: c61a66cf725f
Files: 139
Total size: 7.7 MB

Directory structure:
gitextract_nqr7ihm_/

├── .github/
│   └── FUNDING.yml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── RWKV-8.md
├── RWKV-v1/
│   ├── src/
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v2-RNN/
│   ├── cuda/
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── enwik8-vocab.json
│   ├── run.py
│   ├── src/
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v3/
│   ├── cuda/
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── run.py
│   ├── src/
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4/
│   ├── 20B_tokenizer.json
│   ├── cuda/
│   │   ├── wkv_cuda.cu
│   │   └── wkv_op.cpp
│   ├── run.py
│   ├── src/
│   │   ├── binidx.py
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4neo/
│   ├── 20B_tokenizer.json
│   ├── chat.py
│   ├── cuda/
│   │   ├── wkv5_cuda.cu
│   │   ├── wkv5_op.cpp
│   │   ├── wkv_cuda.cu
│   │   ├── wkv_cuda_bf16.cu
│   │   ├── wkv_op.cpp
│   │   └── wkv_op_bf16.cpp
│   ├── img_demoAE.py
│   ├── math_demo/
│   │   └── run.py
│   ├── run.py
│   ├── src/
│   │   ├── __init__.py
│   │   ├── binidx.py
│   │   ├── dataset.py
│   │   ├── model.py
│   │   ├── model_img.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v5/
│   ├── compute_magic_prime.py
│   ├── cuda/
│   │   ├── wkv5_cuda.cu
│   │   ├── wkv5_op.cpp
│   │   ├── wkv6_cuda.cu
│   │   ├── wkv6_op.cpp
│   │   ├── wkv6state_cuda.cu
│   │   ├── wkv6state_op.cpp
│   │   ├── wkv7_cuda.cu
│   │   └── wkv7_op.cpp
│   ├── demo-training-prepare-v7-pile.sh
│   ├── demo-training-prepare.sh
│   ├── demo-training-run-v7-pile.sh
│   ├── demo-training-run.sh
│   ├── demo.jsonl
│   ├── make_data.py
│   ├── rwkv_v6_demo.py
│   ├── src/
│   │   ├── __init__.py
│   │   ├── binidx.py
│   │   ├── dataset.py
│   │   ├── model.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── tokenizer/
│   │   ├── __init__.py
│   │   └── rwkv_tokenizer.py
│   └── train.py
├── RWKV-v6/
│   └── README.md
├── RWKV-v7/
│   ├── README.md
│   ├── cuda/
│   │   ├── wkv7.cu
│   │   ├── wkv7_op.cpp
│   │   ├── wkv7s.cu
│   │   └── wkv7s_op.cpp
│   ├── misc/
│   │   └── lambada_test.jsonl
│   ├── mmlu_dev_dataset/
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   ├── mmlu_test_dataset/
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   ├── rwkv_mmlu_eval.py
│   ├── rwkv_v7_demo.py
│   ├── rwkv_v7_demo_fast.py
│   ├── rwkv_v7_demo_rnn.py
│   ├── rwkv_v7_numpy.py
│   ├── rwkv_v7a_demo.py
│   ├── rwkv_v7b_demo.py
│   ├── rwkv_v8_rc00_demo.py
│   ├── rwkv_v8_rc00_hybrid_demo.py
│   └── train_temp/
│       ├── README.md
│       ├── cuda/
│       │   ├── rwkv7_clampw.cpp
│       │   ├── rwkv7_clampw.cu
│       │   ├── wkv7_cuda.cu
│       │   ├── wkv7_cuda_fp32.cu
│       │   ├── wkv7_op.cpp
│       │   └── wkv7_op_fp32.cpp
│       ├── demo-training-prepare-v7-pile.sh
│       ├── demo-training-prepare.sh
│       ├── demo-training-run-v7-pile.sh
│       ├── demo-training-run.sh
│       ├── rwkv7_train_simplified.py
│       ├── src/
│       │   ├── __init__.py
│       │   ├── binidx.py
│       │   ├── dataset.py
│       │   ├── model.py
│       │   └── trainer.py
│       └── train.py
├── RWKV-v8/
│   ├── 251014_rosa_1bit_layer.py
│   ├── 251014_rosa_1bit_train.py
│   ├── 251014_rosa_onlyemb_train.py
│   ├── 251016_rosa_1bit_run.py
│   ├── 251018_rosa_4bit_run.py
│   ├── 251024_rosaQKV_run.py
│   ├── 251105_reverse_run.py
│   ├── 260212_rosa1bitLM_L12.py
│   ├── 260222_rosa4bitLM_L12.py
│   ├── README.md
│   └── cuda/
│       ├── wkv7_cuda.cu
│       └── wkv7_op.cpp
└── Research/
    └── rwkv7-g0-7.2b.md

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/FUNDING.yml
================================================
ko_fi: rwkv_lm


================================================
FILE: .gitignore
================================================
*.txt
*.csv
*.pth
*.xlsb
*.xlsx
*.xls
wandb/
data/
vocab.json
*log/
test/
tools/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "PENG"
  given-names: "Bo"
  orcid: "https://orcid.org/0000-0002-0865-547X"
title: "RWKV-LM"
version: 1.0.0
doi: 10.5281/zenodo.5196577
date-released: 2021-08-13
url: "https://github.com/BlinkDL/RWKV-LM"


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced as "RwaKuv" (rʌkuv in IPA), from 4 major params: R W K V)

RWKV website: https://rwkv.com (with 150+ papers training various RWKV models)

RWKV twitter: https://twitter.com/BlinkDL_AI (latest news)

RWKV discord: https://discord.gg/bDSBUMeFpc

RWKV-7 "Goose" is the strongest **linear-time** & **constant-space** (no kv-cache) & **attention-free** & 100% RNN architecture on this planet at this moment, suitable for LLM and multimodal applications and more (see [rwkv.com](https://rwkv.com)).

RWKV-7 is a [meta-in-context learner](https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v7.png), test-time-training its state on the context via in-context gradient descent at every token.

RWKV is a [Linux Foundation AI project](https://lfaidata.foundation/projects/rwkv/), so totally free. RWKV runtime is [already in Windows & Office](https://x.com/BlinkDL_AI/status/1831012419508019550).

You are welcome to ask the RWKV community (such as [RWKV discord](https://discord.gg/bDSBUMeFpc)) for advice on upgrading your attention/ssm models to rwkv7 models :)

**Efficient inference project**: https://github.com/BlinkDL/Albatross
* 145+ token/s RWKV-7 7.2B fp16 bsz1 decoding @ RTX5090 (always const speed)
* 10250+ token/s RWKV-7 7.2B fp16 bsz960 decoding @ RTX5090 (always const speed)
* 11289 token/s RWKV-7 7.2B fp16 bsz1 prefill @ RTX5090 (always const speed)

Latest RWKV weights: https://huggingface.co/BlinkDL

GGUF: https://huggingface.co/collections/shoumenchougou/rwkv7-gxx-gguf

**Fast RWKV-7 CUDA kernels (vanilla, state-tuning, state-passing infctx)**: https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused

My current RWKV7 kernel is 2x slower for 0.1/0.4B vs optimized transformer, but you can reach good speed with 7B+. RWKV7 7.2B training on 4x8xH100 ctx8192 zero2+cp = 206k tokens/s.

**RWKV APP**: https://github.com/RWKV-APP/RWKV_APP (local inference for Android / iOS)

**Simplified RWKV-7 training demo**: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/train_temp/rwkv7_train_simplified.py

**Important** (all shown in rwkv7_train_simplified.py):
* Use PreLN LayerNorm (instead of RMSNorm) for RWKV. I think it's related to better initial state, because I am not using trainable initial state (found it useless when using LayerNorm).
* Only apply weight decay to large matrix parameters (basically projections) in your model instead of all parameters. THIS IS VERY IMPORTANT.
* Use correct initialization.

Improving RNNs: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-8.md

===

**Please use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7/train_temp as RWKV-7 reference implementation**. The default config only requires 1 GPU with 10G VRAM (you can reduce bsz if you have less VRAM), so it's easy to test.

Note that FLA RWKV-7 is NOT aligned with the reference implementation yet, so you will get lower performance.

This is because RWKV-7 is a whole-model design with carefully tuned settings, including different init / wd / lr for each parameter, so it's readily scalable and very stable (spike-free).

But the price to pay is that there is no good, simple "RWKV-7 layer", because a standalone PyTorch layer cannot ensure that it is used with the correct init and hyperparameters.

So if you need to use RWKV-7 for another task, please study train_temp code (only several hundred lines) and change it to suit you.

See: https://github.com/YS-Tang/RWKV-FLA-comparison

<img width="3318" height="2475" alt="image" src="https://github.com/user-attachments/assets/d9f019c2-a178-4837-8539-3a360c0e6801" />

<img width="2656" height="1956" alt="image" src="https://github.com/user-attachments/assets/871d358b-dcd4-4b86-a04b-45c1bcc910b7" />

===

RWKV-8:

<img src="RWKV-8-ROSA.png">

===

History of RWKV (from v1 to v7): [https://wiki.rwkv.com](https://wiki.rwkv.com/) (note: AI-written. might contain errors)

Gradio Demo 1: https://huggingface.co/spaces/BlinkDL/RWKV-Gradio-1

Gradio Demo 2: https://huggingface.co/spaces/BlinkDL/RWKV-Gradio-2

WebGPU Demo: https://cryscan.github.io/web-rwkv-puzzles/#/chat

===

RWKV-Runner GUI: https://github.com/josStorer/RWKV-Runner/releases

Ai00 Server: https://github.com/Ai00-X/ai00_server

RWKV pip pkg: https://pypi.org/project/rwkv/

PEFT (Lora etc.): https://github.com/JL-er/RWKV-PEFT

RLHF: https://github.com/OpenMOSE/RWKV-LM-RLHF

400+ RWKV projects: https://github.com/search?o=desc&q=rwkv&s=updated&type=Repositories

**Faster RWKV-7 kernels**: https://github.com/johanwind/wind_rwkv

===

RWKV-5/6 Eagle/Finch paper: https://arxiv.org/abs/2404.05892

Chat demo code: https://github.com/BlinkDL/ChatRWKV/blob/main/API_DEMO_CHAT.py

**RWKV-7 demo code**: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo.py (GPT-like mode)

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo_rnn.py (RNN mode)

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo_fast.py (Both modes, fastest)

RWKV-6 demo code: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/rwkv_v6_demo.py

RWKV-6 demo code: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

## HOW TO TRAIN RWKV-7/6/5 on MiniPile (1.5G tokens) ##

For reference, use python 3.10+, torch 2.5+, cuda 12.4+, latest deepspeed, but **keep pytorch-lightning==1.9.5**

**Train RWKV-7:**
```
# you can use latest torch + latest cuda (not limited to cu121)
pip install torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu121
pip install pytorch-lightning==1.9.5 deepspeed wandb ninja --upgrade

# train RWKV-7
cd RWKV-v7/train_temp/ 

# download minipile .bin .idx to train_temp/data first (check demo-training-prepare.sh)
# this will generate the initial weight rwkv-init.pth in out/....../
sh ./demo-training-prepare.sh

# this will load rwkv-init.pth and train the model. you may want to log in to wandb first
sh ./demo-training-run.sh

your out/....../train_log.txt should have losses similar to:
0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0
1 4.028621 56.1834 0.00059899 2025-04-24 02:28:16.674463 1
2 3.801625 44.7739 0.00059773 2025-04-24 02:32:51.059568 2
3 3.663070 38.9808 0.00059597 2025-04-24 02:37:25.409892 3
4 3.578974 35.8368 0.00059371 2025-04-24 02:41:59.711315 4
5 3.510906 33.4786 0.00059096 2025-04-24 02:46:33.990839 5
6 3.462345 31.8917 0.00058771 2025-04-24 02:51:08.378331 6
7 3.412196 30.3318 0.00058399 2025-04-24 02:55:42.927474 7
8 3.376724 29.2747 0.00057978 2025-04-24 03:00:17.504665 8
9 3.336911 28.1321 0.00057511 2025-04-24 03:04:52.006063 9
10 3.313411 27.4787 0.00056999 2025-04-24 03:09:27.563336 10
11 3.295895 27.0016 0.00056441 2025-04-24 03:14:01.786079 11
```

RWKV-7 weight example for 1.5B (L24-D2048, vocab 65536):

**Make sure you only apply wd to large tensors (with "wdecay" in comment) here**, or the performance will be much worse.

| name                | shape         | comment      | initialization  |
|---------------------|---------------|--------------|-----------------|
| emb.weight          | [65536, 2048] | wdecay       | see code        |
| blocks.0.ln0.weight | [2048]        | for layer 0  | 1               |
| blocks.0.ln0.bias   | [2048]        | for layer 0  | 0               |
|                     |               |              |                 |
| blocks.*.ln1.weight | [2048]        |              | 1               |
| blocks.*.ln1.bias   | [2048]        |              | 0               |
| blocks.*.att.x_r    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.x_w    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.x_k    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.x_v    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.x_a    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.x_g    | [1, 1, 2048]  |              | see code        |
| blocks.*.att.w0     | [1, 1, 2048]  | lr 2x        | see code        |
| blocks.*.att.w1     | [2048, 96]    |              | 0               |
| blocks.*.att.w2     | [96, 2048]    |              | see code        |
| blocks.*.att.a0     | [1, 1, 2048]  |              | 0               |
| blocks.*.att.a1     | [2048, 96]    |              | 0               |
| blocks.*.att.a2     | [96, 2048]    |              | see code        |
| blocks.*.att.v0     | [1, 1, 2048]  | for layer 1+ | 1               |
| blocks.*.att.v1                | [2048, 64]   | for layer 1+ | 0         |
| blocks.*.att.v2                | [64, 2048]   | for layer 1+ | see code  |
| blocks.*.att.g1                | [2048, 256]  |              | 0         |
| blocks.*.att.g2                | [256, 2048]  |              | see code  |
| blocks.*.att.k_k               | [1, 1, 2048] |              | 1         |
| blocks.*.att.k_a               | [1, 1, 2048] |              | 1         |
| blocks.*.att.r_k               | [32, 64]     |              | 0         |
| blocks.*.att.receptance.weight | [2048, 2048] | wdecay       | see code  |
| blocks.*.att.key.weight        | [2048, 2048] | wdecay       | see code  |
| blocks.*.att.value.weight      | [2048, 2048] | wdecay       | see code  |
| blocks.*.att.output.weight     | [2048, 2048] | wdecay       | 0         |
| blocks.*.att.ln_x.weight       | [2048]       |              | see code  |
| blocks.*.att.ln_x.bias         | [2048]       |              | 0         |
|                                |              |              |           |
| blocks.*.ln2.weight            | [2048]       |              | 1         |
| blocks.*.ln2.bias              | [2048]       |              | 0         |
| blocks.*.ffn.x_k               | [1, 1, 2048] |              | see code  |
| blocks.*.ffn.key.weight        | [8192, 2048] | wdecay       | see code  |
| blocks.*.ffn.value.weight      | [2048, 8192] | wdecay       | 0         |
|                                |              |              |           |
| ln_out.weight | [2048]        |        | 1         |
| ln_out.bias   | [2048]        |        | 0         |
| head.weight   | [65536, 2048] | wdecay | see code  |

Train RWKV-6: use /RWKV-v5/ and use --my_testing "x060" in demo-training-prepare.sh and demo-training-run.sh

Your loss curve should look almost exactly the same as this, with the same ups and downs (if you use the same bsz & config):

<img src="RWKV-v5-minipile.png" width="500">

You can run your model using https://pypi.org/project/rwkv/ (use "rwkv_vocab_v20230424" instead of "20B_tokenizer.json")

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/make_data.py to prepare binidx data from jsonl, and compute "--my_exit_tokens" and "--magic_prime".

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/compute_magic_prime.py to compute "--my_exit_tokens" and "--magic_prime" for existing binidx.

Much faster tokenizer of large data: https://github.com/cahya-wirawan/json2bin https://github.com/cahya-wirawan/rwkv-tokenizer https://github.com/m8than/RWKV-World-Tokenizer-CPP

The "epoch" in train.py is "mini-epoch" (not real epoch. only for convenience), and 1 mini-epoch = 40320 * ctx_len tokens.

For example, if your binidx has 1498226207 tokens and ctxlen=4096, set "--my_exit_tokens 1498226207" (this will override epoch_count), and it will be 1498226207/(40320 * 4096) = 9.07 miniepochs. The trainer will auto-exit after "--my_exit_tokens" tokens. Set "--magic_prime" to the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/4096-1 = 365776), which is "--magic_prime 365759" in this case.

simple: prepare SFT jsonl => repeat your SFT data 3 or 4 times in make_data.py. more repetition leads to overfitting.

advanced: repeat your SFT data 3 or 4 times in your jsonl (note make_data.py will shuffle all jsonl items) => add some base data (such as slimpajama) to your jsonl => and only repeat 1 time in make_data.py.

**Fix training spikes**: see the "Fixing RWKV-6 Spikes" part on this page. 

Or use RWKV-7 (much better). RWKV-7 is very stable and spike-free (verified for 0.1/0.4/1.5/2.9b):
<img src="RWKV-v7-loss.png" width="500">

**Simple inference for RWKV-6**: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

**Simple inference for RWKV-5**: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v5_demo.py

**Note: In [state = kv + w * state] everything must be in fp32 because w can be very close to 1. So we can keep state and w in fp32, and convert kv to fp32.**

lm_eval: https://github.com/BlinkDL/ChatRWKV/blob/main/run_lm_eval.py

**Tips for small model / small data**: When I train RWKV music models, I use deep & narrow (such as L29-D512) dimensions, and apply wd and dropout (such as wd=2 dropout=0.02). Note RWKV-LM dropout is very effective - use 1/4 of your usual value.

## HOW TO TRAIN RWKV-7 on Pile (332G tokens) ##

See https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/demo-training-prepare-v7-pile.sh and https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/demo-training-run-v7-pile.sh

Get these files first:

pile_20B_tokenizer_text_document.bin (664230651068 bytes)

pile_20B_tokenizer_text_document.idx (4212099722 bytes)

### HOW TO FINETUNE RWKV-5 MODELS ###

Use .jsonl format for your data (see https://huggingface.co/BlinkDL/rwkv-5-world for formats).

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/make_data.py to tokenize it using the World tokenizer into binidx, suitable for finetuning World models.

Rename the base checkpoint in your model folder to rwkv-init.pth, and change the training commands to use --n_layer 32 --n_embd 4096 --vocab_size 65536 --lr_init 1e-5 --lr_final 1e-5 for 7B.

0.1B = --n_layer 12 --n_embd 768 // 0.4B = --n_layer 24 --n_embd 1024 // 1.5B = --n_layer 24 --n_embd 2048 // 3B = --n_layer 32 --n_embd 2560 // 7B = --n_layer 32 --n_embd 4096

### State-tuning (tuning the initial state. zero inference overhead)

The current implementation is unoptimized and takes the same VRAM as full SFT.

```--train_type "states" --load_partial 1 --lr_init 1 --lr_final 0.01 --warmup_steps 10 (yes, use very high LR)```

use rwkv 0.8.26+ to auto-load the trained "time_state" 

### Initializing RWKV 5/6 Models ###

When you train RWKV from scratch, try my initialization for best performance. Check generate_init_weight() of src/model.py:
```
emb.weight => nn.init.uniform_(a=-1e-4, b=1e-4)
(Note ln0 of block0 is the layernorm for emb.weight)
head.weight => nn.init.orthogonal_(gain=0.5*sqrt(n_vocab / n_embd))

att.receptance.weight => nn.init.orthogonal_(gain=1)
att.key.weight => nn.init.orthogonal_(gain=0.1)
att.value.weight => nn.init.orthogonal_(gain=1)
att.gate.weight => nn.init.orthogonal_(gain=0.1)
att.output.weight => zero

att.ln_x.weight (groupnorm) => ((1 + layer_id) / total_layers) ** 0.7

ffn.key.weight => nn.init.orthogonal_(gain=1)
ffn.value.weight => zero
ffn.receptance.weight => zero
```
!!! If you are using positional embedding, maybe it's better to remove block.0.ln0 and use default initialization for emb.weight instead of my uniform_(a=-1e-4, b=1e-4) !!!

### Fixing RWKV-6 Spikes ###

0. upgrade to RWKV-7. It's very stable.

1. when training from scratch, add "k = k * torch.clamp(w, max=0).exp()" before "RUN_CUDA_RWKV6(r, k, v, w, u)", and remember to change your inference code too. you will see faster convergence.

2. use "--adam_eps 1e-18"

3. "--beta2 0.95" if you see spikes

4. in trainer.py do "lr = lr * (0.01 + 0.99 * trainer.global_step / w_step)" (originally 0.2 + 0.8), and "--warmup_steps 20"

5. "--weight_decay 0.1" leads to better final loss if you are training lots of data. set lr_final to 1/100 of lr_init when doing this.

### Misc

RWKV-7 can do math. See https://github.com/BlinkDL/RWKV-LM/blob/main/Research/rwkv7-g0-7.2b.md for details.

<img width="555" height="784" alt="image" src="https://github.com/user-attachments/assets/095b4576-962f-4274-ae1a-855406ec76c1" />

<img src="RWKV-v7-niah.png">

## Introducing RWKV

RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.

So it's combining the best of RNN and transformer - **great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding** (using the final hidden state).

**All latest RWKV weights:** https://huggingface.co/BlinkDL

**HF-compatible RWKV weights:** https://huggingface.co/RWKV

```python
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV                         # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')

out, state = model.forward([187, 510, 1563, 310, 247], None)   # use 20B_tokenizer.json
print(out.detach().cpu().numpy())                   # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state)           # RNN has state (use deepcopy if you want to clone it)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy())                   # same result as above
```

nanoRWKV: https://github.com/BlinkDL/nanoRWKV (does not require custom CUDA kernel to train, works for any GPU/CPU)

**Cool Community RWKV Projects**:

All (400+) RWKV projects: https://github.com/search?o=desc&q=rwkv&s=updated&type=Repositories

https://github.com/OpenGVLab/Vision-RWKV Vision RWKV

https://github.com/feizc/Diffusion-RWKV Diffusion RWKV

https://github.com/cgisky1980/ai00_rwkv_server Fastest WebGPU inference (nVidia/AMD/Intel)

https://github.com/cryscan/web-rwkv backend for ai00_rwkv_server

https://github.com/saharNooby/rwkv.cpp Fast CPU/cuBLAS/CLBlast inference: int4/int8/fp16/fp32

https://github.com/JL-er/RWKV-PEFT lora/pissa/Qlora/Qpissa/state tuning

https://github.com/RWKV/RWKV-infctx-trainer Infctx trainer

https://github.com/daquexian/faster-rwkv

https://github.com/mlc-ai/mlc-llm/pull/1275

https://github.com/TheRamU/Fay/blob/main/README_EN.md Digital Assistant with RWKV

https://github.com/harrisonvanderbyl/rwkv-cpp-cuda Fast GPU inference with cuda/amd/vulkan

**RWKV v6 in 250 lines** (with tokenizer too): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

**RWKV v5 in 250 lines** (with tokenizer too): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v5_demo.py

**RWKV v4 in 150 lines** (model, inference, text generation): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py

**RWKV v4 preprint** https://arxiv.org/abs/2305.13048

**RWKV v4 introduction, and in 100 lines of numpy**: https://johanwind.github.io/2023/03/23/rwkv_overview.html https://johanwind.github.io/2023/03/23/rwkv_details.html

![RWKV-7](RWKV-v7.png)

![MQAR](Research/RWKV-6-MQAR.png)

![RWKV-paper](RWKV-paper.png)

RWKV v6 illustrated:

![RWKV-v6](rwkv-x060.png)

![RWKV-v5-benchmark-1](RWKV-v5-benchmark-1.png)

A cool paper (Spiking Neural Network) using RWKV: https://github.com/ridgerchu/SpikeGPT

You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.

![RWKV-eval2](RWKV-eval2.png)

RWKV [loss vs token position] for 10000 ctx4k+ documents in Pile. RWKV 1B5-4k is mostly flat after ctx1500, but 3B-4k and 7B-4k and 14B-4k have some slopes, and they are getting better. This debunks the old view that RNNs cannot model long ctxlens. We can predict that RWKV 100B will be great, and RWKV 1T is probably all you need :)

![RWKV-ctxlen](RWKV-ctxlen.png)

ChatRWKV with RWKV 14B ctx8192:

![RWKV-chat](RWKV-chat.png)

I believe RNN is a better candidate for fundamental models, because: (1) It's more friendly for ASICs (no kv cache). (2) It's more friendly for RL. (3) When we write, our brain is more similar to RNN. (4) The universe is like an RNN too (because of locality). Transformers are non-local models.

RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M

GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M

Training speed: (new training code) RWKV-4 14B BF16 ctxlen4096 = 114K tokens/s on 8x8 A100 80G (ZERO2+CP). (old training code) RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.

I am doing image experiments too (For example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion :) My idea: 256x256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32x32 grid -> pretend the grids are independent and "diffuse" using these probabilities.

Smooth training - no loss spikes! (lr & bsz change around 15G tokens)
![RWKV-loss](RWKV-loss.png)

![RWKV-eval](RWKV-eval.png)

All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, so you can even run a LLM on your phone.

How it works: RWKV gathers information to a number of channels, which are also decaying with different speeds as you move to the next token. It's very simple once you understand it.

**RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable)**. For example, in usual RNN you can adjust the time-decay of a channel from say 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use outputs of later layers of the previous token) if you want extra performance.

![RWKV-formula](RWKV-formula.png)

Here are some of my TODOs. Let's work together :)

* HuggingFace integration (check https://github.com/huggingface/transformers/issues/17230), and optimized CPU & iOS & Android & WASM & WebGL inference. RWKV is an RNN and very friendly for edge devices. Let's make it possible to run an LLM on your phone.

* Test it on bidirectional & MLM tasks, and image & audio & video tokens. I think RWKV can support Encoder-Decoder via this: for each decoder token, use a learned mixture of [decoder previous hidden state] & [encoder final hidden state]. Hence all decoder tokens will have access to the encoder output.

* Now training RWKV-4a with one single tiny extra attention (just a few extra lines compared with RWKV-4) to further improve some difficult zeroshot tasks (such as LAMBADA) for smaller models. See https://github.com/BlinkDL/RWKV-LM/commit/a268cd2e40351ee31c30c5f8a5d1266d35b41829

User feedback:
> *I've so far toyed around the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good - similar ppl to models taking much, much longer to train.*

> *dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.*

Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227

You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai/get-involved/

![RWKV-demo](RWKV-demo.png)

## Quick start

**IMPORTANT: Use deepspeed==0.7.0 pytorch-lightning==1.9.5 torch==1.13.1+cu117 and cuda 11.7.1 or 11.7 (note torch2 + deepspeed has weird bugs and hurts model performance)**

Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).

Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)
```python
prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this
```

### Inference

**Run RWKV-4 Pile models:** Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).

**Colab for RWKV-4 Pile 1.5B**: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM

Run RWKV-4 Pile models in your browser (and onnx version): see this issue https://github.com/BlinkDL/RWKV-LM/issues/7

RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)

For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).

### Training / Fine-tuning

pip install deepspeed==0.7.0 // pip install pytorch-lightning==1.9.5 // torch 1.13.1+cu117

NOTE: add weight decay (0.1 or 0.01) and dropout (0.1 or 0.01) when training on small amt of data. try x=x+dropout(att(x)) x=x+dropout(ffn(x)) x=dropout(x+att(x)) x=dropout(x+ffn(x)) etc.

**Training RWKV-4 from scratch:** run train.py, which by default is using the enwik8 dataset (unzip https://data.deepai.org/enwik8.zip).

You will be training the "GPT" version because it's parallelizable and faster to train. RWKV-4 can extrapolate, so training with ctxLen 1024 can work for ctxLen of 2500+. You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens.

**Fine-tuning RWKV-4 Pile models:** use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/train.py to train it.

Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).

Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb

**Large corpus:** Use https://github.com/Abel2076/json2binidx_tool to convert .jsonl into .bin and .idx

The jsonl format sample (one line for each document):
```
{"text": "This is the first document."}
{"text": "Hello\nWorld"}
{"text": "1+1=2\n1+2=3\n2+2=4"}
```
generated by code like this:
```
ss = json.dumps({"text": text}, ensure_ascii=False)
out.write(ss + "\n")
```

**Infinite ctxlen training (WIP):** https://github.com/Blealtan/RWKV-LM-LoRA/tree/dev-infctx

### How to use RWKV hidden state as text embedding

Consider RWKV 14B. The state has 200 vectors, that is, 5 vectors for each block: fp16 (xx), fp32 (aa), fp32 (bb), fp32 (pp), fp16 (xx).

Do not avg pool because different vectors (xx aa bb pp xx) in the state have very different meanings and ranges. You can probably remove pp.

I suggest first collecting the mean+stdev statistics of each channel of each vector, and normalizing all of them (note: the normalization should be data-independent and collected from various texts). Then train a linear classifier.

## Towards RWKV-5 (just to record some new ideas)

### Latest Design

RWKV-5 is multi-head; a single head is shown here. There is also a LayerNorm for each head (hence actually GroupNorm).

$`
\begin{array}{|l|l|l|}
\hline & \text { RWKV-4 with real-valued } k \,\&\, v \,\&\, u \,\&\, w & \text { RWKV-5 with matrix-valued } \mathrm{k}^{\dagger} \mathrm{v} \,\&\, \mathrm{u} \,\&\, \mathrm{w} \\
\hline \mathrm{y}_0 & \mathrm{r}_0 \frac{\mathrm{uk}_0 \mathrm{v}_0}{\mathrm{uk}_0} & \mathrm{r}_0\left(\mathrm{uk}_0^{\dagger} \mathrm{v}_0\right) \\
\hline \mathrm{y}_1 & \mathrm{r}_1 \frac{\mathrm{uk}_1 \mathrm{v}_1+\mathrm{k}_0 \mathrm{v}_0}{\mathrm{uk}_1+\mathrm{k}_0} & \mathrm{r}_1\left(\mathrm{uk}_1^{\dagger} \mathrm{v}_1+\mathrm{k}_0^{\dagger} \mathrm{v}_0\right) \\
\hline \mathrm{y}_2 & \mathrm{r}_2 \frac{\mathrm{uk}_2 \mathrm{v}_2+\mathrm{k}_1 \mathrm{v}_1+\mathrm{wk}_0 \mathrm{v}_0}{\mathrm{uk}_2+\mathrm{k}_1+\mathrm{wk}_0} & \mathrm{r}_2\left(\mathrm{uk}_2^{\dagger} \mathrm{v}_2+\mathrm{k}_1^{\dagger} \mathrm{v}_1+\mathrm{wk}_0^{\dagger} \mathrm{v}_0\right) \\
\hline \mathrm{y}_3 & \mathrm{r}_3 \frac{\mathrm{uk}_3 \mathrm{v}_3+\mathrm{k}_2 \mathrm{v}_2+\mathrm{wk}_1 \mathrm{v}_1+\mathrm{w}^2 \mathrm{k}_0 \mathrm{v}_0}{\mathrm{uk}_3+\mathrm{k}_2+\mathrm{wk}_1+\mathrm{w}^2 \mathrm{k}_0} & \mathrm{r}_3\left(\mathrm{uk}_3^{\dagger} \mathrm{v}_3+\mathrm{k}_2^{\dagger} \mathrm{v}_2+\mathrm{wk}_1^{\dagger} \mathrm{v}_1+\mathrm{w}^2 \mathrm{k}_0^{\dagger} \mathrm{v}_0\right) \\
\hline
\end{array}`$

$`\left[\begin{array}{ll}
\mathrm{y}_{20} & \cdots \mathrm{y}_{2 \mathrm{c}}
\end{array}\right]=\left[\begin{array}{lll}
\mathrm{r}_{20} & \cdots & \mathrm{r}_{2 \mathrm{c}}
\end{array}\right]`$
$`\left(\left[\begin{array}{ccc}
\mathrm{u}_{00} & \cdots & \mathrm{u}_{0 \mathrm{c}} \\
\vdots & \ddots & \vdots \\
\mathrm{u}_{\mathrm{c} 0} & \cdots & \mathrm{u}_{\mathrm{cc}}
\end{array}\right]\left[\begin{array}{ccc}
\mathrm{k}_{20} \mathrm{v}_{20} & \cdots & \mathrm{k}_{20} \mathrm{v}_{2 \mathrm{c}} \\
\vdots & \ddots & \vdots \\
\mathrm{k}_{2 \mathrm{c}} \mathrm{v}_{20} & \cdots & \mathrm{k}_{2 \mathrm{c}} \mathrm{v}_{2 \mathrm{c}}
\end{array}\right]+\left[\begin{array}{ccc}
\mathrm{k}_{10} \mathrm{v}_{10} & \cdots & \mathrm{k}_{10} \mathrm{v}_{1 \mathrm{c}} \\
\vdots & \ddots & \vdots \\
\mathrm{k}_{1 \mathrm{c}} \mathrm{v}_{10} & \cdots & \mathrm{k}_{1 \mathrm{c}} \mathrm{v}_{1 \mathrm{c}}
\end{array}\right]+\left[\begin{array}{ccc}
\mathrm{w}_{00} & \cdots & \mathrm{w}_{0 \mathrm{c}} \\
\vdots & \ddots & \vdots \\
\mathrm{w}_{\mathrm{c} 0} & \cdots & \mathrm{w}_{\mathrm{cc}}
\end{array}\right]\left[\begin{array}{ccc}
\mathrm{k}_{00} \mathrm{v}_{00} & \cdots & \mathrm{k}_{00} \mathrm{v}_{0 \mathrm{c}} \\
\vdots & \ddots & \vdots \\
\mathrm{k}_{0 \mathrm{c}} \mathrm{v}_{00} & \cdots & \mathrm{k}_{0 \mathrm{c}} \mathrm{v}_{0 \mathrm{c}}
\end{array}\right]
\right)`$

### RWKV-6

Dynamic Mix & Dynamic Decay. Example (do this for both TimeMix & ChannelMix):
```
TIME_MIX_EXTRA_DIM = 32
self.time_mix_k_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_k_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_v_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_v_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_r_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_r_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_g_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_g_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
...
time_mix_k = self.time_mix_k.view(1,1,-1) + (x @ self.time_mix_k_w1) @ self.time_mix_k_w2
time_mix_v = self.time_mix_v.view(1,1,-1) + (x @ self.time_mix_v_w1) @ self.time_mix_v_w2
time_mix_r = self.time_mix_r.view(1,1,-1) + (x @ self.time_mix_r_w1) @ self.time_mix_r_w2
time_mix_g = self.time_mix_g.view(1,1,-1) + (x @ self.time_mix_g_w1) @ self.time_mix_g_w2

xx = self.time_shift(x)
xk = x * time_mix_k + xx * (1 - time_mix_k)
xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
xg = x * time_mix_g + xx * (1 - time_mix_g)
```

![RWKV-v6](RWKV-v6.png)

### RWKV-7

Use parallelized mode to quickly generate the state, then use a finetuned full RNN (the layers of token n can use outputs of all layers of token n-1) for sequential generation.

### Some old ideas

1. Now time decay is like 0.999^T (0.999 is learnable). Change it to something like (0.999^T + 0.1) where 0.1 is learnable too. The 0.1 part will be kept forever. Or, A^T + B^T + C = fast-decay + slow-decay + constant. Can even use different formulas (for example, K^2 instead of e^K for a decay component, or, without normalization).

2. Use complex-valued decay (so, rotation instead of decay) in some channels.

3. Inject some trainable and extrapolatable positional encoding?

4. Aside from 2d rotation, we can try other Lie groups such as 3d rotation ( SO(3) ). Non-abelian RWKV lol.

5. RWKV might be great on analog devices (search for Analog Matrix-vector multiplication & Photonic Matrix-vector multiplication). The RNN mode is very hardware-friendly (processing-in-memory). Can be a SNN too (https://github.com/ridgerchu/SpikeGPT). I wonder if it can be optimized for quantum computation.

6. Trainable initial hidden state (xx aa bb pp xx).

7. Layerwise (or even row/column-wise, elementwise) LR, and test Lion optimizer.

### Vision Tasks

1. I find it's good to add a 2d pos encoding:
```
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
...
x = x + pos_emb_x + pos_emb_y
```

2. In a BPE language model, it's best to use [tokenShift of 1 token] (you can mix more tokens in a char-level English model). However you can try [tokenShift of N (or N-1) (or N+1) tokens] if the image size is N x N, because that will be like mixing [the token above the current position (or the token above the to-be-predicted position)] with [current token]. You can try different tokenShift styles for "ATT" & "FFN", or mix different tokenShift styles - such as mixing [token A] with [token A-1] and [token A-(N-1)] etc.

### Misc

Maybe we can improve memorization by simply repeating the context (I guess 2 times is enough). Example:  Reference -> Reference(again) -> Question -> Answer

#### Idea: Bytes-aware Embedding

The idea is to make sure each token in vocab understands its length and raw UTF-8 bytes.

Let a = max(len(token)) for all token in vocab. Define AA : float[a][d_emb]

Let b = max(len_in_utf8_bytes(token)) for all token in vocab. Define BB : float[b][256][d_emb]

For each token X in vocab, let [x0, x1, ..., xn] be its raw UTF-8 bytes. We will add some extra values to its embedding EMB(X):

EMB(X) += AA[len(X)] + BB[0][x0] + BB[1][x1] + ... + BB[n][xn] (note: AA BB are learnable weights)

* We can do this for the final Linear(d_emb, n_vocab) projection too.
* We can use some small networks to generate AA and BB, for some extra regularization (for example, BB[m][xi] and BB[n][xi] should be related).

#### Old Idea

I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:

Channel 0 = "space"

Channel 1 = "capitalize first letter"

Channel 2 = "capitalize all letters"

Therefore:

Embedding of "abc":  [0, 0, 0, x0, x1, x2 , ..]

Embedding of " abc":  [1, 0, 0, x0, x1, x2, ..]

Embedding of " Abc":  [1, 1, 0, x0, x1, x2, ..]

Embedding of "ABC": [0, 0, 1, x0, x1, x2, ...]

......

so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".

Note: the above method is assuming that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.

Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.

Maybe the Best: let 'abc' ' abc' etc. to share the last 90% of their embeddings.

At this moment, all our tokenizers spend too many items to represent all variations of 'abc' ' abc' ' Abc' etc. Moreover the model cannot discover that these are actually similar if some of these variations are rare in the dataset. The method here can improve this. I plan to test this in a new version of RWKV.

#### Idea: Better Initial States

Example (single-round Q & A):

1. Generate the final state of all wiki documents.

2. For any user Q, find the best wiki document, and use its final state as the initial state.

3. Train a model to directly generate the optimal initial state for any user Q.

However this can be a bit more tricky for multi-round Q & A :)

## How it works

RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).

Moreover it's using a number of my tricks, such as:

* SmallInitEmb: https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers) which helps the embedding quality, and stabilizes Post-LN (which is what I am using).

* Token-shift: https://github.com/BlinkDL/RWKV-LM#token-shift-time-shift-mixing (applicable to all transformers), especially helpful for char-level models.

* Head-QK: https://github.com/BlinkDL/RWKV-LM#the-head-qk-trick-learning-to-copy-and-avoid-tokens (applicable to all transformers). Note: it's helpful, but I disabled it in the Pile model to keep it 100% RNN.

* Extra R-gate in the FFN (applicable to all transformers). I am also using reluSquared from Primer.

* Better initialization: I init most of the matrices to ZERO (see RWKV_Init in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py).

* You can transfer some parameters from a small model to a large model (note: I sort & smooth them too), for faster and better convergence (see https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/).

* My CUDA kernel: https://github.com/BlinkDL/RWKV-CUDA to speedup training.

## The pseudocode (execution from top to bottom):

![RWKV-v2-RNN](RWKV-v2-RNN.png)

The a b c d factors work together to build a time-decay curve: [X, 1, W, W^2, W^3, ...].

Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
* a and b: EMAs of kv and k.
* c and d: these are a and b combined with "self-attention".

kv / k is the memory mechanism. The token with high k can be remembered for a long duration, if W is close to 1 in the channel.

The R-gate is important for performance. k = info strength of this token (to be passed to future tokens). r = whether to apply the info to this token.

## RWKV-3 improvements

Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:
```python
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
```

Use preLN instead of postLN (more stable & faster convergence):
```python
if self.layer_id == 0:
	x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
```

## Explaining the code for RWKV-3 GPT mode

### The GPT mode - overview

The building blocks of RWKV-3 GPT mode are similar to that of a usual preLN GPT.

The only difference is an extra LN after embedding. Note you can absorb this LN into the embedding after finishing the training.
```python
x = self.emb(idx)  # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x)    # output: x = logits
```
It is important to initialize emb to tiny values, such as nn.init.uniform_(a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.

For the 1.5B RWKV-3, I use Adam (no wd, no dropout) optimizer on 8 * A100 40G.

batchSz = 32 * 896, ctxLen = 896. I am using tf32 so the batchSz is a bit small. 

For the first 15B tokens, LR is fixed at 3e-4, and beta=(0.9, 0.99).

Then I set beta=(0.9, 0.999), and do an exponential decay of LR, reaching 1e-5 at 332B tokens.

### The GPT mode - ATT block

The RWKV-3 does not have any attention in the usual sense, but we will call this block ATT anyway.
```python
B, T, C = x.size() # x = (Batch,Time,Channel)

# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v

# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)

# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
    wkv = TimeX.apply(w, kv, B,C,T, 0)
    wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
    w = w[:,-T:].unsqueeze(1)
    wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
    wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS

# The RWKV formula
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv) # final output projection
```

The self.key, self.receptance, self.output matrices are all initialized to zero.

The time_mix, time_decay, time_first vectors are transferred from a smaller trained model (note: I sort & smooth them too).

### The GPT mode - FFN block

The FFN block has three tricks compared with the usual GPT:

1. My time_mix trick.

2. The sqReLU from the Primer paper.

3. An extra receptance-gate (similar to the receptance-gate in ATT block).
```python
# Mix x with the previous timestep to produce xk, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

# The usual FFN operation
k = self.key(xk)
k = torch.square(torch.relu(k)) # from the Primer paper
kv = self.value(k)

# Apply an extra receptance-gate to kv
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
The self.value, self.receptance matrices are all initialized to zero.

## RWKV-4 improvements

![RWKV-v3-plan](RWKV-v3-plan.png)

## From GPT to RWKV (the formulas)

Let F[t] be the system state at t.

Let x[t] be the new external input at t.

In GPT, predicting F[t+1] requires considering F[0], F[1], .. F[t]. So it takes O(T^2) to generate a length T sequence.

The **simplified formula** for GPT:

![F[\mathrm{t}+1]=\frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)

It's very capable in theory, however that **does not mean we can fully utilize its capability with usual optimizers**. I suspect the loss landscape is too difficult for our current methods.

Compare with the **simplified formula** for RWKV (the parallel mode, looks similar to Apple's AFT):

![F[\mathrm{t}+1]=\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K }F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)

The R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).

In GPT, the contribution of F[i] to F[t+1] is weighted by ![ \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle++%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).

In RWKV-2, the contribution of F[i] to F[t+1] is weighted by ![\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
* The ![\sigma](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma) is a non-linearity and we can use sigmoid. 
* Note ![\sigma(\mathbf{R}x[\mathrm{t}])](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29) is not in the denominator, and I call R the "receptance".
* The ![\exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i}))](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29) is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it the "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).

Here comes the punchline: we can rewrite it into an RNN (recursive formula). Note:

![F[1]=\sigma(\mathbf{R }x[0]) \cdot \frac{ \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{\exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B0%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)

![F[2]=\sigma(\mathbf{R }x[1]) \cdot \frac{ \exp (\mathbf{K }F[1]) \cdot(\mathbf{V }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{ \exp (\mathbf{K }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B2%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B1%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)

Therefore it's straightforward to verify:

![F[t+1]=\sigma(\mathbf{R }x[t]) \cdot \frac{\exp (\mathbf{K}F[\mathrm{t}]) \cdot(\mathbf{V}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot A[\mathrm{t}]}{ \exp (\mathbf{K}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot B[\mathrm{t}]}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5Bt%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5Bt%5D%29+%5Ccdot+%5Cfrac%7B%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+A%5B%5Cmathrm%7Bt%7D%5D%7D%7B+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+B%5B%5Cmathrm%7Bt%7D%5D%7D)

where A[t] and B[t] are the numerator and denominator of the previous step, respectively.

I believe RWKV is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.

Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=BlinkDL/RWKV-LM&type=Date)](https://star-history.com/#BlinkDL/RWKV-LM&Date)

## Multimodal ideas

I have an idea for [text --> 32x32 RGB image] using a LM (transformer, RWKV, etc.). Will test it soon.

Firstly, LM loss (instead of L2 loss), so the image will not be blurry.

Secondly, color quantization. For example, only allowing 8 levels for R/G/B. Then the image vocab size is 8x8x8 = 512 (for each pixel), instead of 2^24.
Therefore, a 32x32 RGB image = a len1024 sequence of vocab512 (image tokens), which is a typical input for usual LMs.
(Later we can use diffusion models to upsample and generate RGB888 images. We might be able to use a LM for this too.)

Thirdly, 2D positional embeddings that are easy for the model to understand.
For example, add one-hot X & Y coords to the first 64(=32+32) channels. Say if the pixel is at x=8, y=20, then we will add 1 to channel 8 and channel 52 (=32+20).
Moreover probably we can add the float X & Y coords (normalized to 0~1 range) to another 2 channels. And other periodic pos. encoding might help too (will test). 

Finally, RandRound when doing the color quantization in the DataLoader.
For example, if the float level is 4.578, then there is a 57.8% chance to use 5, and (1-57.8%) chance to use 4.
And we can allow both 4 and 5 in the prediction, but the loss will be higher if the prediction is 4.

Multi-task training might help too. I will try this dataset format:
[TxtFirst] [Desc of Img (txt tokens)] [Img] [img tokens]
and sometimes
[ImgFirst] [img tokens] [Txt] [Desc of Img (txt tokens)]
... the order of the imgs should be randomized in the DataLoader, and [TxtFirst] [ImgFirst] [Img] [Txt] are special tokens
and do random sampling of the full dataset. So sometimes the model will see the img tokens first and then the corresponding txt tokens, which is a [img -> txt] task. And the model will see some partial imgs and partial txts. I think a char-level LM might help the model to write correct text on images.

## How to sample a large dataset (for training)

I am using a trick to sample the Pile deterministically yet randomly enough.

Let's say the pile has x chunks (a chunk = ctx_len tokens).

Pick a prime number p just less than x, and make sure p = 2 (mod 3).

Use (step * step * step) mod p to sample it. Add some bias to step for extra randomness.

## The top-p-x sampling method (for inference)

We propose a new sampling method called top-p-x:

it's like top-p, and the only difference is you also keep all tokens whose prob > x.

Try x = 0.01 first.

## Better Learning Rate Schedule via Variational Method of Loss Curve

I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is we can model the loss curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using a variational method. Moreover we can predict the final loss with reasonable accuracy.

UPDATE: In "Conclusion 1.", use the best-fitting regime (ignore the initial steps where our approximations break down) to fit the parameters.

Try this: fixed lr for 1 hr, then exponential decay to 0.2 * lr in 12 hrs, and choose the t=[1hr, 13hr] segment.

In the last three plots, black = predicted loss curve of the new LR schedule, blue = original (unoptimized) real loss curve, orange = new LR schedule.

![better_lr_schedule](Research/better_lr_schedule.png)

# RWKV v1

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

<img src=
"https://render.githubusercontent.com/render/math?math=%5Cdisplaystyle+%5Cbegin%7Balign%2A%7D%0A%5Ctext%7BTime-mix+%3A%7D+%26%26+%5Ctext%7BTM%7D_%7Bt%2Cc%7D+%26%26%3D%26%26%5Ctext%7Bsigmoid%7D%28%5Ctext%7BR%7D_%7Bt%2Cc%7D%29+%26%26%5Ccdot%26%26+%26%26%5Ctextstyle%5Csum_%7Bu%7D+%26%26%5Ctextbf%7BW%7D_%7Bt%2Cu%2Cc%7D+%26%26%5Ccdot%26%26+%5Ctext%7Bsoftmax%7D_t%28%5Ctext%7BK%7D_%7Bu%2Cc%7D%29+%26%26%5Ccdot%26%26+%5Ctext%7BV%7D_%7Bu%2Cc%7D%5C%5C%0A%5Ctext%7BChannel-mix+%3A%7D+%26%26+%5Ctext%7BCM%7D_%7Bt%2Cc%7D+%26%26%3D%26%26%5Ctext%7Bsigmoid%7D%28%5Ctext%7BR%7D_%7Bt%2Cc%7D%29+%26%26%5Ccdot%26%26+%26%26%5Ctextstyle%5Csum_d+%26%26%5Ctextbf%7BW%7D_%7Bc%2Cd%7D+%26%26%5Ccdot%26%26+%5Ctext%7Bgelu%7D%28%5Ctext%7BK%7D_%7Bt%2Cd%7D%29+%26%26%5Ccdot%26%26+%5Ctext%7BV%7D_%7Bt%2Cd%7D%0A%5Cend%7Balign%2A%7D%0A" 
alt="\begin{align*}
\text{Time-mix :} && \text{TM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_{u} &&\textbf{W}_{t,u,c} &&\cdot&& \text{softmax}_t(\text{K}_{u,c}) &&\cdot&& \text{V}_{u,c}\\
\text{Channel-mix :} && \text{CM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_d &&\textbf{W}_{c,d} &&\cdot&& \text{gelu}(\text{K}_{t,d}) &&\cdot&& \text{V}_{t,d}
\end{align*}
">

* The R, K, V are generated by linear transforms of input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we can call R "receptance", and sigmoid means it's in 0~1 range.

* The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.

(1) We changed the normalization (denominator). For masked language models, we define:

<img src=
"https://render.githubusercontent.com/render/math?math=%5Cdisplaystyle+%5Ctext%7Bsoftmax%7D_t%28%5Ctext%7BK%7D_%7Bu%2Cc%7D%29+%3D+%5Cfrac%7B%5Cexp%28%5Ctext%7BK%7D_%7Bu%2Cc%7D%29%7D%7B%5Csum_%7Bv+%5Cleq+t%7D%5Cexp%28%5Ctext%7BK%7D_%7Bv%2Cc%7D%29%7D" 
alt="\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}">

**(UPDATE: We are using the original AFT normalization in v2)**
 
Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.
 
(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

<img src=
"https://render.githubusercontent.com/render/math?math=%5Cdisplaystyle+W_%7Bt%2Cu%2Cc%7D%3Df_h%28t-u%29%5Ccdot+%5Calpha_h%28u%29+%5Ccdot+%5Cbeta_h%28t%29" 
alt="W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)">

Moreover we multiply the final output of Time-mix layer by γ(t). The reason for the α β γ factors, is because the context size is smaller when t is small, and this can be compensated using the α β γ factors.

**(UPDATE: We remove α β γ factors in v2-RNN and restrict W to be of a simple form and hence able to rewrite it as RNN)**

* The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor. Initialize R and W matrices to ZERO for fast & stable convergence.

* Finally, we add extra token-shift (time-shift mixing) as in (https://github.com/BlinkDL/minGPT-tuned).

# Token-shift (time-shift mixing)

The token-shift explicitly uses (half the channels of this token) & (half the channels of prev token) to generate all vectors (QKV, RWKV, ...).

```
self.time_shift = nn.ZeroPad2d((0,0,1,-1))

x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
```

Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.

However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).

My theory on the effectiveness of token-shift:

When we train a GPT, the hidden representation of a token has to accomplish two different objectives:

1. Predict the next token. Sometimes this is easy (obvious next token).

2. Collect all previous context info, so later tokens can use it. This is always hard.

The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.

You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.

p.s. There is a MHA_pro model in this repo with strong performance. Give it a try :)

# The Head-QK Trick: learning to copy and avoid tokens

In a usual transformer, a small model has difficulty copying tokens (such as person names) in the context. We add extra Q & K to the final output such that the model can directly copy (or avoid) tokens in the context. Afterwards the model will teach itself NER (named entity recognition) if you look at the learned weights.
```
q = self.head_q(x)[:,:T,:] # projecting to 256-d
k = self.head_k(x)[:,:T,:] # projecting to 256-d
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()       
x = self.head(x) + c
```
Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).

# The top-a sampling method

We also propose a new sampling method called top-a (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.2 * pow(p_max, 2). So it's adaptive, hence "top-a".

(3) Feel free to tune the 0.2 and 2 factor. Tune 0.2 first.

The idea of top-a:
1. If max_prob=0.9, then remove all tokens with prob < 0.162 (so, removing all alternatives)
2. If max_prob=0.5, then remove all tokens with prob < 0.05  (so, allowing more choices)
3. If max_prob=0.1, then remove all tokens with prob < 0.002 (so, allowing lots of possibilities)

```
probs = F.softmax(logits, dim=-1)

limit = torch.pow(torch.max(probs), 2) * 0.2
logits[probs < limit] = -float('Inf')
```

# Performance

Character-level loss on simplebooks-92 dataset https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip

![RWKV-vs-MHA](RWKV-vs-MHA.png)

Gray: usual MHA+Rotary+GeGLU - performance not as good. 17.2M params.

Red: RWKV ("linear" attention) - VRAM friendly - quite faster when ctx window is long - good performance. 16.6M params.

Green: MHA+Rotary+GeGLU+Token_shift. 17.2M params.

Blue: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance. 16.6M params.

```
@software{peng_bo_2021_5196578,
  author       = {PENG Bo},
  title        = {BlinkDL/RWKV-LM: 0.01},
  month        = aug,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {0.01},
  doi          = {10.5281/zenodo.5196577},
  url          = {https://doi.org/10.5281/zenodo.5196577}
}
```

# Initialization

We use careful initialization for RWKV to get fast convergence - orthogonal matrices with proper scaling, and special time_w curves. Check model.py for details.

Some learned time_w examples:

![RWKV-time-w](RWKV-time-w.png)


================================================
FILE: RWKV-8.md
================================================
# Improving RNNs (RWKV-8 and beyond)

Here I will show a framework to improve current RNNs.

## 1. Larger State

This includes larger head size, larger inner model, hybrid models, ...

For RNNs, larger state => better performance, but bsz will be limited. And I believe massive parallel prefilling+decoding (large bsz, multi-agent) is the future.

From my view, there is a ladder of states: scalar state => vector state => matrix state (most current RNNs) => tensor state (a few papers tried this) => function state (attention is actually here, because it's kernel regression) => functional state => functor state => higher functor state => ...

Can certainly go beyond linear algebra: group, lie group, differential geometry, function space, category and higher categories, ... and only limited by efficient hardware implementation.

Indeed, new hardware (analog, quantum, ...) can change space and time complexity of some items in the ladder, and we are very far from endgame.

Example of a practical step. Tensor states can be efficient, if only used in some heads, such as the slowest-decaying head. Use the better sum_{i,j} a[i] b[j] s[i,j,n] instead of the common idea sum_{i,j} a[i] a[j] s[i,j,n], and a 64x64x64 state can be a good starting point.

Note RWKV-4 has particularly small states, and is a good candidate for improvements.

## 2. Smaller State

This includes various tricks: sparse state, structured state, shared state, compressed state, low-rank state, quantized state, ... which can be found in various shrink-kv-cache papers too.

From my view, we can consider 6 dimensions: B (bsz), T (ctxlen), H (head), N (headsz), L (layer), Q (bits).

RNN statesz = f(B,H,N,L,Q). Transformer statesz = f(B,T,H,N,L,Q).

Can apply any trick to any dimension. Good for bingo.

Example:

H + sparse: use a router to select head.

N + sparse: use a router to select state inside a head. Larger state, similar I/O.

L + share: just like how a few papers proposed sharing kv cache between layers.

L + sparse: no need to go through all layers for all tokens.

T + compress: such as, compressing tokens into super-tokens, and can use raw bytes without tokenizer. Or, different ctxlen in different layers, such as T T/2 T T/2, T T/2 T/4 etc, and can restrict this to the hybrid attention part too.

Plenty of possibilities for each X + Y combination, and good for NAS.

## 3. Mixed State

Mixing state between heads. Mixing state between layers. These are expensive (when doing bwd). Can do them periodically, or when necessary. Can do them at readout (cheaper).

Mixing state of the last layer of token n, with the state of the first layer of token n+1. A depth-L model becomes a depth-2L model after a step of this, and still efficiently trainable.

## 4. Fancy State Evolution

Example: Let A = evolution matrix. Try exp(sA)-1, 1/(1-sA), etc. with trainable dynamic s.

Example: DeltaProduct, fancy inner optimizers, fancy inner models.

These are all beneficial, and the question is {depth-L1 model with fancy state evolution} vs {depth-L2 model with simple state evolution} where L2 > L1 and speed-matched.

### Conclusion: we have room for 100 architecture papers here.

There are a number of more advanced methods beyond these, which I am exploring for RWKV-8.


================================================
FILE: RWKV-v1/src/__init__.py
================================================


================================================
FILE: RWKV-v1/src/model.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
    """Re-initialize, in place, every nn.Linear and nn.Embedding inside `module`.

    For each such layer a `gain` and a `scale` are chosen and multiplied:
      * product == 0 -> the weight is zero-filled
      * product  > 0 -> orthogonal init with that gain
      * product  < 0 -> normal init with std = -product
    Linear biases are always zeroed. One summary line is printed per layer
    (rows, cols, scale, parameter name).

    Args:
        module: root nn.Module; all of its sub-modules are visited.
        config: object providing `vocab_size`, `n_embd` and `rwkv_emb_scale`.
            A layer whose weight has shape (vocab_size, n_embd) uses
            `rwkv_emb_scale` as its scale.

    A layer can override the scale by carrying a `scale_init` attribute
    (e.g. `scale_init = 0` requests zero initialization).
    """
    # Resolve parameter names once. The previous per-layer scan rebound `name`
    # on every iteration, so a weight that was not found ended up labelled with
    # the *last* parameter's name instead of the '[unknown weight]' placeholder.
    param_names = {id(parameter): name for name, parameter in module.named_parameters()}

    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = param_names.get(id(m.weight), '[unknown weight]')

            shape = m.weight.data.shape
            gain = 1.0  # positive: gain for orthogonal, negative: std for normal
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
                    scale = config.rwkv_emb_scale

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
                    scale = config.rwkv_emb_scale

            # an explicit per-layer request wins over the shape-based defaults
            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if gain == 0:
                nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0, std=-gain)

class RWKV_TimeMix(nn.Module):
    """RWKV time-mix layer: a receptance-gated, distance-weighted sum over past positions.

    Per head h the layer learns a weight curve `time_w[h]` over distances
    0..ctx_len-1, plus per-position factors `time_alpha`, `time_beta` and an
    output factor `time_gamma`. `key`, `receptance` and `output` are tagged
    with `scale_init = 0` so that RWKV_Init zero-initializes them.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_attn % config.n_head == 0
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_attn // config.n_head

        ctx_len = config.ctx_len
        n_head = config.n_head
        with torch.no_grad(): # initial time_w curves for better convergence
            ww = torch.ones(n_head, ctx_len)
            # negative distance to the last slot: -(ctx_len-1), ..., -1, 0
            curve = torch.tensor([i - (ctx_len - 1) for i in range(ctx_len)])
            for h in range(n_head):
                # every head but the last decays; the last one stays flat (all ones)
                decay_speed = math.pow(ctx_len, -(h+1)/(n_head-1)) if h < n_head - 1 else 0.0
                ww[h] = torch.exp(curve * decay_speed)
        self.time_w = nn.Parameter(ww)

        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(ctx_len, 1))

        # shifts the sequence one step forward in time (zero row in front, last row dropped)
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)

        self.output = nn.Linear(config.n_attn, config.n_embd)

        # ask RWKV_Init for zero initialization of these three projections
        for layer in (self.key, self.receptance, self.output):
            layer.scale_init = 0

    def forward(self, x):
        """Apply time-mixing to `x` of shape (B, T, C); returns a (B, T, n_embd) tensor."""
        B, T, C = x.size()
        TT = self.ctx_len

        # Pad/tile/reshape trick that lays time_w out as a lower-triangular
        # Toeplitz matrix per head: decay[h, t, u] = time_w[h, TT-1-(t-u)] for
        # u <= t and 0 for u > t (so position t only sees positions <= t).
        decay = F.pad(self.time_w, (0, TT))
        decay = torch.tile(decay, [TT])
        decay = decay[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        decay = decay[:, :, TT-1:]
        decay = decay[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        # token-shift: first half of the channels comes from the previous position
        half = C // 2
        x = torch.cat([self.time_shift(x[:, :, :half]), x[:, :, half:]], dim = -1)

        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        # clamp extreme values before exponentiating. e^30 = 10^13
        k = torch.exp(torch.clamp(k, max=30, min=-60))
        sum_k = torch.cumsum(k, dim=1) # running normalizer over positions <= t

        kv = (k * v).view(B, T, self.n_head, self.head_size)
        wkv = torch.einsum('htu,buhc->bthc', decay, kv).contiguous().view(B, T, -1)

        gated = torch.sigmoid(r) * wkv / sum_k

        return self.output(gated) * self.time_gamma[:T, :]

class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mix layer: a gated feed-forward block with a receptance gate.

    Computes sigmoid(R) * W(mish(K) * V) on the token-shifted input.
    `receptance` and `weight` carry `scale_init = 0` so RWKV_Init zeroes them.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # shifts the sequence one step forward in time (zero row in front, last row dropped)
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        n_embd = config.n_embd
        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)
        self.receptance = nn.Linear(n_embd, n_embd)

        for layer in (self.receptance, self.weight):
            layer.scale_init = 0

    def forward(self, x):
        """Apply channel-mixing to `x` of shape (B, T, C); output has the same shape."""
        _, _, C = x.size()
        half = C // 2

        # token-shift: first half of the channels comes from the previous position
        mixed = torch.cat([self.time_shift(x[:, :, :half]), x[:, :, half:]], dim = -1)
        k = self.key(mixed)
        v = self.value(mixed)
        r = self.receptance(mixed)

        hidden = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu

        return torch.sigmoid(r) * hidden

class RWKV_TinyAttn(nn.Module): # extra tiny attention
    """Small softmax self-attention with a fused q/k/v projection.

    Width and head count come from `config.rwkv_tiny_attn` and
    `config.rwkv_tiny_head`; the result is projected back to `config.n_embd`.
    """

    def __init__(self, config):
        super().__init__()
        self.d_attn = config.rwkv_tiny_attn
        self.n_head = config.rwkv_tiny_head
        self.head_size = self.d_attn // self.n_head

        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
        self.out = nn.Linear(self.d_attn, config.n_embd)

    def forward(self, x, mask):
        """Attend over `x` (B, T, C); score entries where `mask == 0` are excluded."""
        B, T, C = x.size()
        q, k, v = self.qkv(x).chunk(3, dim = -1)

        multi_head = self.n_head > 1
        if multi_head:
            # (B, T, C) -> (B, nh, T, hs)
            q, k, v = (t.view(B, T, self.n_head, self.head_size).transpose(1, 2) for t in (q, k, v))

        # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim = -1)
        # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        mixed = weights @ v

        if multi_head:
            # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
            mixed = mixed.transpose(1, 2).contiguous().view(B, T, -1)

        return self.out(mixed)

########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    """Produces (and caches) the cos/sin tables used for rotary position encoding.

    Args:
        dim: number of rotary channels; `inv_freq` holds one frequency per
            channel pair (`torch.arange(0, dim, 2)`).
        base: base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        """Return `(cos, sin)` tables of shape (seq_len, 2 * len(inv_freq)).

        Args:
            x: tensor used for its device and, when `seq_len` is omitted, for
                its length along dim -2 (the axis apply_rotary_pos_emb slices).
            seq_len: number of positions to cover; defaults to `x.shape[-2]`.
        """
        # Previously seq_len=None returned (None, None) on a fresh module and
        # crashed in torch.arange(None) once a cache existed.
        if seq_len is None:
            seq_len = x.shape[-2]

        # Rebuild when the length changes or when the cached tables live on a
        # different device than the current input (stale after a device move).
        stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            # outer product position x frequency (same values as einsum 'i,j->ij')
            freqs = t[:, None] * self.inv_freq[None, :]
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached

def rotate_half(x):
    """Split the last dim at `size // 2` and return `[-back, front]` concatenated."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([-back, front], dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate `q` and `k` with the given cos/sin tables and return both.

    The tables are first cut to `q.shape[-2]` entries along their second-to-last
    dim; the same cut tables are then applied to `q` and to `k`.
    """
    seq_len = q.shape[-2]
    cos = cos[..., :seq_len, :]
    sin = sin[..., :seq_len, :]
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

class MHA_rotary(nn.Module):
    """Causal multi-head self-attention with rotary encoding on part of each head.

    The first `rotary_ndims` (= int(head_size * 0.5)) channels of every head
    are rotated; the remaining channels pass through unchanged. With
    `time_shift=True` the input is token-shifted before the projections.
    """

    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        if time_shift:
            # only present when requested; forward() checks for it with hasattr
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        # lower-triangular ones: position t may attend to positions <= t
        causal = torch.tril(torch.ones(config.ctx_len, config.ctx_len))
        self.register_buffer("mask", causal)

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return a (B, T, n_embd) tensor."""
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            half = C // 2
            x = torch.cat([self.time_shift(x[:, :, :half]), x[:, :, half:]], dim = -1)

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # rotary encoding on the leading channels only, then re-attach the rest
        n_rot = self.rotary_ndims
        q_rot, q_pass = q[..., :n_rot], q[..., n_rot:]
        k_rot, k_pass = k[..., :n_rot], k[..., n_rot:]
        cos, sin = self.rotary_emb(q_rot, seq_len=T)
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q = torch.cat((q_rot, q_pass), dim=-1)
        k = torch.cat((k_rot, k_pass), dim=-1)

        # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))   # causal mask
        att = F.softmax(att, dim = -1)

        # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        merged = (att @ v).transpose(1, 2).contiguous().view(B, T, -1)

        return self.output(merged)

class GeGLU(torch.nn.Module):
    """GELU-gated feed-forward block: W(gelu(K x) * V x), hidden width 3 * n_ffn.

    With `time_shift=True` the input is token-shifted before the projections.
    """

    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            # only present when requested; forward() checks for it with hasattr
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        n_embd = config.n_embd
        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)

    def forward(self, x):
        """Transform `x` of shape (B, T, C); output has shape (B, T, n_embd)."""
        _, _, C = x.size()
        if hasattr(self, 'time_shift'):
            half = C // 2
            x = torch.cat([self.time_shift(x[:, :, :half]), x[:, :, half:]], dim = -1)

        gate = F.gelu(self.key(x))
        return self.weight(gate * self.value(x))

########################################################################################################
# MHA_pro: with more tricks
########################################################################################################

class MHA_pro(nn.Module):
    """Causal multi-head attention with extra tricks stacked on top.

    On top of rotary-encoded softmax attention it applies, in this order:
    token-shift on the input, a learned distance weighting of the attention
    map (`time_w`, `time_alpha`, `time_beta`), a 1x1 convolution that mixes
    attention maps across heads ("talking heads"), and a per-position output
    factor `time_gamma`.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        ctx_len = config.ctx_len
        self.time_w = nn.Parameter(torch.ones(self.n_head, ctx_len))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(ctx_len, 1))
        # lower-triangular ones: position t may attend to positions <= t
        self.register_buffer("mask", torch.tril(torch.ones(ctx_len, ctx_len)))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False)  # talking heads

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return a (B, T, n_embd) tensor."""
        B, T, C = x.size()
        TT = self.ctx_len

        # Pad/tile/reshape trick that lays time_w out as a lower-triangular
        # Toeplitz matrix per head: decay[h, t, u] = time_w[h, TT-1-(t-u)] for
        # u <= t and 0 for u > t.
        decay = F.pad(self.time_w, (0, TT))
        decay = torch.tile(decay, [TT])
        decay = decay[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        decay = decay[:, :, TT-1:]
        decay = decay[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        # time-shift mixing
        half = C // 2
        x = torch.cat([self.time_shift(x[:, :, :half]), x[:, :, half:]], dim = -1)

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # rotary encoding on the leading channels only, then re-attach the rest
        n_rot = self.rotary_ndims
        q_rot, q_pass = q[..., :n_rot], q[..., n_rot:]
        k_rot, k_pass = k[..., :n_rot], k[..., n_rot:]
        cos, sin = self.rotary_emb(q_rot, seq_len=T)
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q = torch.cat((q_rot, q_pass), dim=-1)
        k = torch.cat((k_rot, k_pass), dim=-1)

        # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))   # causal mask
        att = F.softmax(att, dim = -1)
        att = att * decay                                             # time-weighting
        att = self.head_mix(att)                                      # talking heads

        # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        merged = (att @ v).transpose(1, 2).contiguous().view(B, T, -1)

        return self.output(merged) * self.time_gamma[:T, :]

########################################################################################################
# The GPT Model with our blocks
########################################################################################################

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable per-feature gain."""

    def __init__(self, d):
        """d: size of the last (feature) dimension of the tensors passed to forward()."""
        super().__init__()
        self.dd = d ** (-1. / 2)                     # 1 / sqrt(d)
        self.weight = nn.Parameter(torch.ones(d))    # gain, initialised to 1

    def forward(self, x):
        # L2 norm of the features scaled by 1/sqrt(d) is their root-mean-square
        rms = x.norm(2, dim=-1, keepdim=True) * self.dd
        # the tiny constant keeps the division finite for an all-zero feature vector
        return self.weight * (x / (rms + 1e-12))

class FixedNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, without any learnable parameter."""

    def __init__(self, d):
        """d: size of the last (feature) dimension of the tensors passed to forward()."""
        super().__init__()
        self.dd = d ** (-1. / 2)                     # 1 / sqrt(d)

    def forward(self, x):
        # L2 norm of the features scaled by 1/sqrt(d) is their root-mean-square
        rms = x.norm(2, dim=-1, keepdim=True) * self.dd
        # the tiny constant keeps the division finite for an all-zero feature vector
        return x / (rms + 1e-12)

########################################################################################################

class GPTConfig:
    """Attribute bag describing a model: two required sizes plus arbitrary extra settings."""

    def __init__(self, vocab_size, ctx_len, **kwargs):
        """Store vocab_size and ctx_len, then every keyword argument, as instance attributes."""
        settings = dict(vocab_size=vocab_size, ctx_len=ctx_len)
        settings.update(kwargs)
        for name, value in settings.items():
            setattr(self, name, value)

class Block(nn.Module):
    """One pre-LayerNorm residual block: ``x + attn(ln1(x))`` followed by ``x + mlp(ln2(x))``.

    The attention and feed-forward sub-modules are selected by ``config.model_type``:
    'RWKV', 'MHA_rotary', 'MHA_shift' or 'MHA_pro'.
    """

    def __init__(self, config, layer_id):
        """Build the block.

        Args:
            config: object providing at least ``n_embd`` and ``model_type``; it is also handed
                unchanged to the selected sub-modules.
            layer_id: index of this block, forwarded to the sub-module constructors.

        Raises:
            ValueError: if ``config.model_type`` is not one of the supported names. Without this
                check the block would be built without ``attn``/``mlp`` and only fail later,
                inside ``forward``, with an unrelated-looking AttributeError.
        """
        super().__init__()
        self.config = config

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if config.model_type == 'RWKV':
            # self.ln1 = FixedNorm(config.n_embd)
            # self.ln2 = FixedNorm(config.n_embd)
            self.attn = RWKV_TimeMix(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)

        elif config.model_type == 'MHA_rotary':
            self.attn = MHA_rotary(config, layer_id)
            self.mlp = GeGLU(config, layer_id)

        elif config.model_type == 'MHA_shift':
            self.attn = MHA_rotary(config, layer_id, time_shift=True)
            self.mlp = GeGLU(config, layer_id, time_shift=True)

        elif config.model_type == 'MHA_pro':
            self.attn = MHA_pro(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)

        else:
            raise ValueError(
                "unknown model_type %r (expected 'RWKV', 'MHA_rotary', 'MHA_shift' or 'MHA_pro')"
                % (config.model_type,))

    def forward(self, x):
        """Apply the two residual sub-layers and return a tensor shaped like ``x``."""
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))

        return x

class GPT(nn.Module):
    """Language model: token embedding -> ``config.n_layer`` Blocks -> LayerNorm -> vocabulary head.

    In addition to the usual head, ``forward`` adds a "copy" term built from ``head_q`` / ``head_k``
    that adjusts the logits of tokens which already occurred at the current or an earlier position.
    """

    def __init__(self, config):
        """Create every sub-module and initialise the weights.

        Attributes read from ``config`` here: vocab_size, n_embd, n_layer, ctx_len, model_type.
        """
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # 256-dim query/key projections used by the copy term in forward().
        # NOTE(review): scale_init is only assigned here; presumably it is read by RWKV_Init — confirm.
        self.head_q = nn.Linear(config.n_embd, 256)
        self.head_q.scale_init = 0.01
        self.head_k = nn.Linear(config.n_embd, 256)
        self.head_k.scale_init = 0.01
        # lower-triangular ones (diagonal included): position t may look at positions <= t
        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        # RWKV models use their own initialiser; every other model type gets _init_weights
        if self.config.model_type == 'RWKV':
            RWKV_Init(self, config)
        else:
            self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_ctx_len(self):
        """Return the maximum sequence length accepted by forward()."""
        return self.ctx_len

    def _init_weights(self, module):
        """Per-module initialiser for ``self.apply``: N(0, 0.01) weights for Linear/Embedding, zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with two parameter groups (weight-decayed / not decayed).

        Reads weight_decay, learning_rate, betas and eps from ``train_config``.
        Raises AssertionError if a parameter lands in both groups or in neither.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        whitelist_weight_modules = (nn.Linear, )
        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
        for mn, m in self.named_modules():
            # m.named_parameters() recurses, so a parameter is visited once per ancestor module;
            # the sets de-duplicate and the isinstance tests only match on its direct owner
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                # biases and anything whose full name contains 'time' or 'head' are never decayed
                if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # sorted() makes the order of parameters inside each group deterministic
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
        return optimizer

    def forward(self, idx, targets=None):
        """Compute logits for token ids ``idx`` of shape (B, T) and, if ``targets`` is given, the loss.

        Returns ``(logits, loss)`` where logits has a last dimension of ``config.vocab_size`` and
        ``loss`` is the cross-entropy against ``targets`` (or None when no targets are passed).
        Raises AssertionError when T exceeds the model's ctx_len.
        """
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.tok_emb(idx)

        x = self.blocks(x)

        x = self.ln_f(x)

        # copy term: score every (position, other position) pair ...
        q = self.head_q(x)[:,:T,:]
        k = self.head_k(x)[:,:T,:]
        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
        # ... zero the scores of later positions (upper triangle) ...
        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
        # ... and route each remaining score to the vocabulary slot of the token found at that position
        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()

        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
        x = self.head(x) + c

        loss = None
        if targets is not None:
            # flatten batch and time so every position counts as one classification example
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss


================================================
FILE: RWKV-v1/src/trainer.py
================================================
import math, sys, datetime
import logging
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)

# print('logging to wandb... (comment it if you don\'t have wandb)')
# import wandb # comment this if you don't have wandb

class TrainerConfig:
    """Training hyper-parameters. Class attributes are the defaults; keyword arguments override them per instance."""

    # --- optimisation ---
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    weight_decay = 0.01
    # --- learning-rate schedule ---
    lr_decay = False # linear warmup followed by cosine decay
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
    final_tokens = 260e9 # at which point do we reach lr_final
    # --- checkpointing ---
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    # --- data loading ---
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        """Set every keyword argument as an instance attribute, shadowing the class-level default."""
        for option, value in kwargs.items():
            setattr(self, option, value)

class Trainer:
    """Training loop for a model whose call ``model(x, y)`` returns ``(output, loss)``.

    The raw model must also provide ``configure_optimizers(config)``; when Weights & Biases is
    loaded it must additionally expose ``config`` (vocab_size, ctx_len, model_type, n_layer, n_embd).
    """

    def __init__(self, model, train_dataset, test_dataset, config):
        """Store the arguments, optionally start a wandb run, and move the model to the GPU if one exists.

        Args:
            model: the network to train.
            train_dataset: dataset yielding ``(x, y)`` pairs for training.
            test_dataset: dataset for the 'test' split (may be None; train() never evaluates it).
            config: a TrainerConfig-like object.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1   # negative means "no loss recorded yet"
        self.steps = 0
        self.tokens = 0      # counter used for learning rate decay

        if 'wandb' in sys.modules:
            # Bind the already-loaded module explicitly: this file does not import wandb itself,
            # so relying on a module-level name would raise NameError when wandb was imported elsewhere.
            wandb = sys.modules['wandb']
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k]) # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available(): # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def get_run_name(self):
        """Return '<vocab_size>-<ctx_len>-<model_type>-<n_layer>-<n_embd>' for the wrapped model."""
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        """Train for ``config.max_epochs`` epochs and return.

        A checkpoint (the whole raw model) is written to ``epoch_save_path + str(epoch+1) + '.pth'``
        every ``epoch_save_frequency`` epochs (when > 0) and always after the last epoch.
        When ``config.lr_decay`` is true, ``config.lr_final`` must be set as well.
        """
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)
        wandb = sys.modules.get('wandb')   # None when wandb has not been imported anywhere

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device) # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y) # forward the model
                    loss = loss.mean()         # collapse all losses if they are scattered on multiple gpus

                if is_train: # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # Defaults for the constant-lr case; without them `progress` was undefined
                    # (NameError in the progress-bar text) whenever lr_decay was False.
                    lr = config.learning_rate
                    progress = 0.0
                    if config.lr_decay: # decay the learning rate based on our progress
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            # progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
                            lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                    now_loss = loss.item() # report progress

                    if wandb is not None:
                        wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
                    self.steps += 1

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        # factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
                        # running mean that restarts its weighting at the beginning of every epoch
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                    pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        # One pass over max_epochs, then return. The former enclosing `while True:` made this
        # method loop forever, so callers never got control back after training.
        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
                torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth')


================================================
FILE: RWKV-v1/src/utils.py
================================================
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def top_k_logits(logits, k):
    """Return a copy of ``logits`` in which, per row, everything below the k-th largest value is -inf.

    Values tied with the k-th largest are kept. The input tensor is not modified.
    """
    kth_largest = torch.topk(logits, k).values[:, -1:]   # smallest of the k kept values, one per row
    return logits.masked_fill(logits < kth_largest, -float('Inf'))

def top_p_probs(probs, p):
    """Nucleus (top-p) filter along the last dimension.

    Entries are visited from most to least probable; the shortest prefix whose cumulative
    probability exceeds ``p`` is kept (so the most probable entry always survives) and every
    other entry is set to 0. The result is NOT renormalised and ``probs`` is not modified.

    Works for a single distribution (1-D) as well as for any batch of distributions: the
    filtered values are scattered back along the last dimension, whereas the previous
    ``out[indices] = 0`` only addressed the right elements for 1-D input.
    """
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # shift the "over budget" flags right by one so the entry that crosses p is still kept
    sorted_remove = cumulative_probs > p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    kept_sorted = sorted_probs.masked_fill(sorted_remove, 0)
    # sorted_indices is a permutation of each row, so every output slot gets written exactly once
    return torch.empty_like(probs).scatter_(-1, sorted_indices, kept_sorted)

# top-p + top-k + pow&ratio sampling
def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
    """Sample one token id from the distribution at sequence position ``pos``.

    ``logits`` is indexed as ``logits[:, pos, :]``; only the first batch row is sampled from.
    Filters, applied in this order when enabled:
      * min_p_ratio (needs min_p_pow): drop entries with prob < max(prob) ** min_p_pow * min_p_ratio
      * top_k: keep the k largest logits per row
      * top_p: nucleus filter on the first row
    Returns the sampled index as a 0-dim CPU tensor.
    """
    scores = logits[:, pos, :] / temperature   # the division yields a new tensor, the caller's logits stay intact

    if min_p_ratio is not None:
        base_probs = F.softmax(scores, dim=-1)
        threshold = torch.pow(torch.max(base_probs), min_p_pow) * min_p_ratio
        scores[base_probs < threshold] = -float('Inf')

    if top_k is not None:
        scores = top_k_logits(scores, top_k)

    probs = F.softmax(scores, dim=-1)

    if top_p is not None:
        probs[0] = top_p_probs(probs[0], top_p)

    # multinomial accepts unnormalised weights, so the top-p filtered row needs no rescaling
    return torch.multinomial(probs, num_samples=1)[0][0].cpu()

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with the same value."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


================================================
FILE: RWKV-v1/train.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, sys, time, math, random, json, datetime, logging
import numpy as np
import torch
from torch.utils.data import Dataset
from src.trainer import Trainer, TrainerConfig
from src.model import GPT, GPTConfig
from src.utils import set_seed

# Fix all RNG seeds and configure printing / logging before anything else runs.
set_seed(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

# RWKV       : our new model - fastest when ctx_len is long - VRAM friendly - good performance
# MHA_rotary : usual MultiheadAttention+Rotary+GeGLU - not as good
# MHA_shift  : with time-shift - good performance
# MHA_pro    : slow (lots of tricks) - VRAM hungry - very good performance
model_type = 'RWKV'

# Training corpus: a single text file read with `datafile_encoding`.
# The commented-out lines are alternative corpora / encodings.
# datafile = u"V:\\NLP\\text8"
# datafile = u"V:\\NLP\\enwik8"
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"D:\\NLP-Data\\__2019.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
# datafile_encoding = 'utf-16'
# datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
# datafile_encoding = 'utf-32'

datafile_type = 0 # use 0 for char-level English, 1 for Chinese. Only affects some RWKV hyperparameters

#################################### VERY IMPORTANT ####################################
epoch_save_frequency = 10                            # 0 = never, 1 = every 'epoch', 2 = every two 'epochs', etc.
epoch_save_path = 'trained-'

batch_size = 32                                      # if you see "CUDA out of memory", reduce this.
                                                     # if you have a good GPU, increase this.
                                                     # use GPU-Z to find the highest value for your VRAM.

n_epoch = 100                                        # the 'epoch' here is actually very short (and of fixed length)
########################################################################################

model_level = 'character' # 'character' (recommended) or 'word'

ctx_len = 256 # context length, try 512 or 1024 if you have a good GPU
n_layer = 6   # try 12 for 100M, 24 for 300M
n_head = 8    # try 12 for 100M, 16 for 300M

# embedding width is 64 per head; the attention and feed-forward widths follow it
n_embd = n_head * 64
n_attn = n_embd
n_ffn = n_embd

lr_init = 6e-4 if model_type == 'RWKV' else 4e-4    # RWKV can use higher lr.  8e-4 = 0.0008   4e-4 = 0.0004
lr_final = 4e-5

betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99) # both branches currently hold the same value
eps = 4e-9
weight_decay = 0 if model_type == 'RWKV' else 0.01  # wd is not useful when we have enough data

epoch_length_fixed = 10000                          # make an 'epoch' very short, so we can see the training progress

######## special hyperparameters for RWKV model ########
rwkv_emb_scale = 0.4                                # scale of initial embedding. 0.4 is a good choice
rwkv_tiny_attn = 0#64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_head = 1                                  # 1 is good enough. 8 is slow
# n_side_proj = 512                                 # extra 'side projection', quite useful for BPE models

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)

class Dataset(Dataset):
    """Character- or word-level corpus that serves random fixed-length training windows.

    Constructing it builds the vocabulary from ``data`` and writes the id -> token table to
    'vocab.json' (UTF-16) in the current directory.
    """

    def __init__(self, data, model_level, ctx_len):
        """
        Args:
            data: the whole corpus as one string.
            model_level: 'word' to tokenise into lower-cased words and punctuation; any other
                value keeps ``data`` as a sequence of characters.
            ctx_len: number of tokens in each training input.
        """
        print('building token list...', end=' ')
        if model_level == 'word':
            import re
            # surround punctuation, digits and newlines with spaces so that they become separate tokens
            data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data)
            data = re.sub(' +',' ',data)
            print('splitting token...')
            data = data.lower().split(' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        # persist id -> token (json turns the integer ids into string keys)
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(dict(enumerate(unique)), ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
        self.stoi = { ch:i for i,ch in enumerate(unique) }
        self.itos = { i:ch for i,ch in enumerate(unique) }
        self.ctx_len = ctx_len
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        """Nominal epoch length (the module-level ``epoch_length_fixed``), not the corpus size."""
        return epoch_length_fixed

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``ctx_len`` token ids and the same window shifted by one token.

        ``idx`` is ignored; the window position is drawn at random.
        Raises ValueError when the corpus holds fewer than ``ctx_len + 1`` tokens.
        """
        last_start = len(self.data) - (self.ctx_len + 1)
        if last_start < 0:
            raise ValueError('data has %d tokens but ctx_len + 1 = %d are needed' % (len(self.data), self.ctx_len + 1))
        # randint's upper bound is exclusive: +1 makes the final window reachable and lets a
        # corpus of exactly ctx_len + 1 tokens work instead of failing with "low >= high"
        i = np.random.randint(0, last_start + 1) # cheat: pick a random spot in dataset
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

# Read the whole corpus up front; the context manager closes the file handle once it has been read
# (the bare open(...).read() left that to the garbage collector).
with open(datafile, "r", encoding=datafile_encoding) as corpus_file:
    train_dataset = Dataset(corpus_file.read(), model_level, ctx_len)

########################################################################################################
# Train model
########################################################################################################

# vocab_size and ctx_len come from the dataset; every other setting is passed through as a keyword
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))

# load a trained model
# model.load_state_dict(torch.load('trained-xxx.pth').state_dict())

print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
# no warmup; the learning-rate schedule ends after n_epoch * len(train_dataset) * ctx_len tokens
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                        learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
                        warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)

trainer.train()

# final snapshot of the whole model object, named after the run configuration and the current time
torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')


================================================
FILE: RWKV-v2-RNN/cuda/timex_cuda.cu
================================================
#include <stdio.h>

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax, BF and BB are macros supplied on the compiler command line)

// View A as an array of float4 and select the vector starting at scalar index B.
// (B) >> 2 discards the low two bits, so B is expected to be a multiple of 4.
// NOTE(review): assumes A (and A + B) is suitably aligned for float4 access — confirm for all callers.
#define F4(A, B) ((float4 *)(A))[(B) >> 2]

// Forward pass. With k and x viewed as (B*C) rows of length T, for every row r handled here and
// every time step t:
//     x[r][t] = eps + sum_{u = 0..t} w[i % C][T - 1 - (t - u)] * k[r][u]
// i.e. a causal weighted sum whose lag-0 weight is stored last (index T-1).
// Block i (blockIdx.y) processes the BF rows r = i + ij*j, j = 0..BF-1, with ij = B*C/BF; they all
// use weight row i % C, which matches r % C whenever ij is a multiple of C.
// Each thread produces the four consecutive time steps t..t+3, t = 4 * threadIdx.x.
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
                               const F eps, const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BF;
    const int t = threadIdx.x << 2;

    // stage this block's weight row and its BF input rows in shared memory, four floats per thread
    __shared__ F ww[Tmax];
    __shared__ F kk[Tmax * BF];
    F4(ww, t) = F4(__w, t + T * (i % C));
    
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
    }
    __syncthreads();

    // one float4 accumulator per row: .x/.y/.z/.w hold outputs t, t+1, t+2, t+3; all start at eps
    float4 s[BF];
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        s[j] = {eps, eps, eps, eps};
    }
    // shifted view of the weights so that w[u + 3] is the weight pairing input u with output t
    const F *__restrict__ const w = ww + T - t - 4;
    // inputs 0..t contribute to all four outputs
    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BF; j++) {
            const F x = kk[u + Tmax * j];   // scalar input; shadows the output pointer `x` in this scope
            s[j].x += w[u + 3] * x;
            s[j].y += w[u + 2] * x;
            s[j].z += w[u + 1] * x;
            s[j].w += w[u + 0] * x;
        }
    }
    // inputs t+1..t+3 only reach the later outputs of the group (causality), then store the result
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        s[j].y += w[t + 3] * k[t + 1];
        s[j].z += w[t + 2] * k[t + 1];
        s[j].z += w[t + 3] * k[t + 2];
        s[j].w += w[t + 1] * k[t + 1];
        s[j].w += w[t + 2] * k[t + 2];
        s[j].w += w[t + 3] * k[t + 3];
        F4(x, t + T * (i + ij * j)) = s[j];
    }
}

// Weight-gradient-only kernel, one row per block (row i = blockIdx.y). For every time step t:
//     gw[i][t] = sum_{u = 0..t} __gwk[i][T - 1 - (t - u)] * __k[i][u]
// Each thread produces the four consecutive entries t..t+3, t = 4 * threadIdx.x.
// The __w, gk, B and C parameters are not used by this kernel, and no launcher in this file calls it.
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int t = threadIdx.x << 2;

    // stage this row's inputs and incoming gradients in shared memory, four floats per thread
    __shared__ F k[Tmax];
    __shared__ F gg[Tmax];
    F4(k, t) = F4(__k, t + T * i);
    F4(gg, t) = F4(__gwk, t + T * i);
    __syncthreads();

    float4 s = {0, 0, 0, 0};

    // shifted view of the gradients so that g[u + 3] pairs input u with entry t
    const F *__restrict__ const g = gg + T - t - 4;
    // inputs 0..t contribute to all four entries
    for (int u = 0; u <= t; u++) {
        F x = k[u];
        s.x += g[u + 3] * x;
        s.y += g[u + 2] * x;
        s.z += g[u + 1] * x;
        s.w += g[u + 0] * x;
    }
    // inputs t+1..t+3 only contribute to the later entries of the group
    s.y += g[t + 3] * k[t + 1];
    s.z += g[t + 2] * k[t + 1];
    s.z += g[t + 3] * k[t + 2];
    s.w += g[t + 1] * k[t + 1];
    s.w += g[t + 2] * k[t + 2];
    s.w += g[t + 3] * k[t + 3];
    F4(gw, t + T * i) = s;
}
// Host-side launcher for kernel_forward.
// Grid: B*C/BF blocks along y, each covering BF rows. Block: T/4 threads, four time steps per thread.
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    const dim3 blocks(1, B * C / BF);
    const dim3 threads(T >> 2);
    kernel_forward<float><<<blocks, threads>>>(w, k, x, eps, B, C, T);
}

// Backward pass. With k, __gwk, gw and gk viewed as (B*C) rows of length T, for every row r handled
// here and every time step t it writes
//     gw[r][t] = sum_{u = 0..t}   __gwk[r][T - 1 - (t - u)] * __k[r][u]
//     gk[r][t] = sum_{u = t..T-1} __gwk[r][T - 1 + t - u]   * __w[i % C][u]
// Block i (blockIdx.y) processes the BB rows r = i + ij*j, j = 0..BB-1, with ij = B*C/BB.
// Each thread produces the four consecutive entries t..t+3, t = 4 * threadIdx.x.
// gw is written per row; no reduction across rows happens inside this kernel.
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BB;
    const int t = threadIdx.x << 2;

    // stage the weight row plus BB rows of inputs and incoming gradients in shared memory
    __shared__ F w[Tmax];
    __shared__ F kk[Tmax * BB];
    __shared__ F gg[Tmax * BB];
    F4(w, t) = F4(__w, t + T * (i % C));

    #pragma unroll
    for (int j = 0; j < BB; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
        F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
    }
    __syncthreads();

    // ---- part 1: gw ----
    float4 s[BB];
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    // inputs 0..t contribute to all four entries
    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            // shifted view of row j's gradients so that g[u + 3] pairs input u with entry t
            const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
            F x = kk[u + Tmax * j];
            s[j].x += g[u + 3] * x;
            s[j].y += g[u + 2] * x;
            s[j].z += g[u + 1] * x;
            s[j].w += g[u + 0] * x;
        }
    }
    // inputs t+1..t+3 only contribute to the later entries of the group, then store gw
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
        s[j].y += g[t + 3] * k[t + 1];
        s[j].z += g[t + 2] * k[t + 1];
        s[j].z += g[t + 3] * k[t + 2];
        s[j].w += g[t + 1] * k[t + 1];
        s[j].w += g[t + 2] * k[t + 2];
        s[j].w += g[t + 3] * k[t + 3];
        F4(gw, t + T * (i + ij * j)) = s[j];
    }

    // ---- part 2: gk (accumulators are reused) ----
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    // weights t+3..T-1 contribute to all four entries
    for (int u = t + 3; u < T; u++) {
        F x = w[u];
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            // shifted view of row j's gradients so that g[2 - u] pairs weight u with entry t
            const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
            s[j].x += g[2 - u] * x;
            s[j].y += g[3 - u] * x;
            s[j].z += g[4 - u] * x;
            s[j].w += g[5 - u] * x;
        }        
    }
    // weights t..t+2 only contribute to the earlier entries of the group, then store gk
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
        s[j].x += g[2 - t] * w[t + 0];
        s[j].x += g[1 - t] * w[t + 1];
        s[j].x += g[0 - t] * w[t + 2];
        s[j].y += g[2 - t] * w[t + 1];
        s[j].y += g[1 - t] * w[t + 2];
        s[j].z += g[2 - t] * w[t + 2];
        F4(gk, t + T * (i + ij * j)) = s[j];
    }
}
// Host-side launcher for kernel_backward.
// Grid: B*C/BB blocks along y, each covering BB rows. Block: T/4 threads, four time steps per thread.
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    const dim3 blocks(1, B * C / BB);
    const dim3 threads(T >> 2);
    kernel_backward<float><<<blocks, threads>>>(w, k, gwk, gw, gk, B, C, T);
}


================================================
FILE: RWKV-v2-RNN/cuda/timex_op.cpp
================================================
#include <torch/extension.h>

// Kernel launchers; declared here and called by the wrappers below (no definition in this translation unit).
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);

// Tensor-level entry point for the forward pass: hands the raw buffers of w, k and x to cuda_forward.
// The buffers are reinterpreted as float without any dtype, device or contiguity check.
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
    const float *w_ptr = static_cast<const float *>(w.data_ptr());
    const float *k_ptr = static_cast<const float *>(k.data_ptr());
    float *x_ptr = static_cast<float *>(x.data_ptr());
    cuda_forward(w_ptr, k_ptr, x_ptr, eps, B, C, T);
}
// Tensor-level entry point for the backward pass: hands the raw buffers to cuda_backward.
// gw and gk are output buffers; all buffers are reinterpreted as float without any dtype, device
// or contiguity check.
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
    const float *w_ptr = static_cast<const float *>(w.data_ptr());
    const float *k_ptr = static_cast<const float *>(k.data_ptr());
    const float *gwk_ptr = static_cast<const float *>(gwk.data_ptr());
    float *gw_ptr = static_cast<float *>(gw.data_ptr());
    float *gk_ptr = static_cast<float *>(gk.data_ptr());
    cuda_backward(w_ptr, k_ptr, gwk_ptr, gw_ptr, gk_ptr, B, C, T);
}

// Python bindings: exposes forward/backward as functions of the compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "timex forward");
    m.def("backward", &backward, "timex backward");
}

// Operator registration: the same two functions under the `timex` library namespace.
TORCH_LIBRARY(timex, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}


================================================
FILE: RWKV-v2-RNN/run.py
================================================
# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
# Backend switches (cuDNN autotuning, TF32 matmuls) and compact numpy printing.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

# NOTE(review): these presumably have to match the values the checkpoint was trained with — confirm.
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'           # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab'           # the .json vocab (generated by train.py)

# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8'  # uncomment this for EVAL MODE (no text generation)
# ########################################################################

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '   # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'   # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output

### Step 2: set context ################################################################################

context = "\nIn the"       # ==> this is your prompt

NUM_TRIALS = 999           # number of independent generations from the same prompt
LENGTH_PER_TRIAL = 500     # tokens generated per trial

# sampling settings, forwarded to tokenizer.sample_logits in the generation loop
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

# EVAL MODE: only active when EVAL_DATA has been defined above. Measures the average
# next-token loss over N_SAMPLE random windows of ctx_len + 1 characters, then exits
# without generating any text.
if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
    print('Evaluating on ' + EVAL_DATA + ' ...')

    # the context manager closes the file once it has been read
    with open(EVAL_DATA, "r", encoding='utf-8') as eval_file:
        data = eval_file.read()

    loss_table = np.zeros(ctx_len)   # summed loss per position inside the window

    N_SAMPLE = 1000

    for iii in range(N_SAMPLE):
        # randint's upper bound is exclusive; len(data) - ctx_len makes the last full window reachable
        pos = np.random.randint(0, len(data) - ctx_len)
        context = data[pos:pos+ctx_len+1]
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]

        model.clear()
        for i in range(1, ctx_len+1):
            x = ctx[:i]
            out = model.run(x)
            prob = F.softmax(torch.tensor(out), dim=-1)
            # negative log-likelihood of the token that actually follows the prefix
            loss_table[i-1] += -math.log(prob[ctx[i]])

        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
              np.mean(loss_table) / (iii+1))

    # same effect as exit(0), without depending on the helper installed by the `site` module
    raise SystemExit(0)

########################################################################################################

# Text generation: run the prompt through the model once, then sample NUM_TRIALS continuations
# of LENGTH_PER_TRIAL tokens each, restoring the saved post-prompt state for every trial.
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        # first trial: feed the prompt one growing prefix at a time and keep the output of the
        # full prompt together with the model state saved right after it
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        # later trials: restore the saved post-prompt state instead of re-running the prompt
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]   # never pass more than ctx_len tokens

        if i == src_len:
            # the output for the prompt is already known; copy it so the saved one is left untouched
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)], end='', flush=True)
        ctx += [char]
    t_end = time.time_ns()
    # elapsed wall-clock time of this trial, in seconds
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')


================================================
FILE: RWKV-v2-RNN/src/model.py
================================================
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# CUDA Kernel
########################################################################################################

# Compile-time limits handed to the CUDA kernel as -DTmax / -DBF / -DBB below.
# TimeX asserts T % 4 == 0, T <= T_MAX and that the batch size is divisible by
# both group sizes.
T_MAX = 1024          # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4   # 8 is noted as fastest, but the batch size must be divisible by this value
B_GROUP_BACKWARD = 2  # 2 is noted as fastest; the batch size must be divisible by this value too

# JIT-build and load the extension; the source paths are relative to the
# current working directory.
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])


class TimeX(torch.autograd.Function):
    """Autograd wrapper around the compiled ``timex_cuda`` forward/backward kernels.

    ``forward(w, k, B, C, T, eps)`` returns a new ``(B, C, T)`` CUDA tensor
    filled by ``timex_cuda.forward``.  ``backward`` returns gradients for ``w``
    (summed over the batch dimension) and ``k``; the remaining four inputs get
    ``None``.
    """

    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        ctx.B, ctx.C, ctx.T = B, C, T
        # Shape constraints required by the kernel build (see T_MAX / B_GROUP_*).
        assert (ctx.T % 4 == 0 and ctx.T <= T_MAX
                and ctx.B % B_GROUP_FORWARD == 0
                and ctx.B % B_GROUP_BACKWARD == 0)
        w = w.contiguous()
        k = k.contiguous()
        ctx.save_for_backward(w, k)
        result = torch.empty((B, C, T), device='cuda',
                             memory_format=torch.contiguous_format)
        timex_cuda.forward(w, k, result, eps, B, C, T)
        return result

    @staticmethod
    def backward(ctx, gwk):
        assert (ctx.T % 4 == 0 and ctx.T <= T_MAX
                and ctx.B % B_GROUP_FORWARD == 0
                and ctx.B % B_GROUP_BACKWARD == 0)
        w, k = ctx.saved_tensors
        grad_shape = (ctx.B, ctx.C, ctx.T)
        gw = torch.empty(grad_shape, device='cuda',
                         memory_format=torch.contiguous_format)
        gk = torch.empty(grad_shape, device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.backward(w, k, gwk.contiguous(), gw,
                            gk, ctx.B, ctx.C, ctx.T)
        # The kernel writes a per-sample gradient for w; reduce it over the batch.
        grad_w = gw.sum(dim=0)
        return (grad_w, gk, None, None, None, None)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


# Upper bound applied to k before torch.exp() in RWKV_TimeMix.forward.
RWKV_K_CLAMP = 60  # e^60 = 1e26
# Passed as `eps` to the TimeX call that produces the denominator in RWKV_TimeMix.forward.
RWKV_K_EPS = 1e-16
# Output width of GPT.head_q / GPT.head_k; also the divisor applied to q @ k^T in GPT.forward.
RWKV_HEAD_QK_DIM = 256


def RWKV_Init(module, config):  # fancy initialization of all lin & emb layer in the module
    """Initialise every ``nn.Linear`` / ``nn.Embedding`` weight inside ``module`` in place.

    For each such layer a ``gain`` and an extra ``scale`` are chosen:

    * Embedding: ``gain = sqrt(max(shape))``; ``scale = 1e-4`` when the shape is
      ``(config.vocab_size, config.n_embd)``, otherwise ``0``.
    * Linear: bias (if any) is zeroed; ``gain = sqrt(out / in)`` when
      ``out > in``; ``scale = 0.5`` when the shape is
      ``(config.vocab_size, config.n_embd)``.
    * A ``scale_init`` attribute on the layer overrides ``scale``.

    The weight is then set to: identity if ``scale == -999``; zeros if
    ``gain * scale == 0``; orthogonal with that gain if positive; otherwise
    normal with ``std = -scale``.

    Args:
        module: root ``nn.Module`` whose sub-modules are initialised.
        config: object exposing ``vocab_size`` and ``n_embd``.
    """
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            # NOTE: the previous per-layer scan over module.named_parameters()
            # only fed a commented-out debug print and cost
            # O(modules * parameters); it has been dropped.
            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                # negative scale_init requests a plain normal init with std = -scale
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    """RWKV time-mix layer (training form, backed by the ``TimeX`` CUDA op).

    Learnable state: a per-channel ``time_decay`` (stored as a log), a
    per-channel ``time_first``, a per-channel ``time_mix`` blend between the
    current and the previous token, and bias-free ``key`` / ``value`` /
    ``receptance`` / ``output`` projections.

    Note: the constructor moves ``time_curve`` to ``'cuda'``, so building this
    module requires a CUDA device.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        ############# fancy init of time_w curves ###################################
        f1_begin = 3.0
        f1_end = 1.2
        f2_begin = 0.65
        f2_end = 0.4
        with torch.no_grad():  # initial time_w curves for better convergence
            decay_speed = torch.ones(attn_sz, 1)
            first_sa_layer_id = 1
            # f1 / f2 depend only on the layer index, so compute them once
            # instead of once per channel.
            layer_frac = (layer_id-first_sa_layer_id) / \
                (config.n_layer-1-first_sa_layer_id)
            f1 = f1_begin + layer_frac * (f1_end - f1_begin)
            f2 = f2_begin + layer_frac * (f2_end - f2_begin)
            if layer_id == first_sa_layer_id:
                f1 += 0.5
            if layer_id == config.n_layer-2:
                f2 = 0.4
            if layer_id == config.n_layer-1:
                f2 = 0.37
            for h in range(attn_sz):
                decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
        self.time_decay = nn.Parameter(torch.log(decay_speed)) # will use exp(self.time_decay) to ensure time_decay > 0
        # Row vector -(ctx_len-2), ..., -1, 0 of length ctx_len-1.
        self.time_curve = torch.tensor(
            [-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
        self.time_curve = self.time_curve.to('cuda')
        self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
        #############################################################################

        # Shifts the sequence by one step along the second-to-last dimension
        # (zero row in front, last row dropped).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():  # init to "shift half of the channels"
            ww = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd // 2):
                ww[0, 0, i] = 0
        self.time_mix = nn.Parameter(ww)

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # Picked up by RWKV_Init: these three start as all-zero matrices.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        """Apply time-mixing to ``x`` of size ``(B, T, C)`` and return the same size."""
        B, T, C = x.size()

        # Blend each token with the previous one, per channel.
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x).transpose(-1, -2)
        v = self.value(x).transpose(-1, -2)
        r = self.receptance(x)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP)
        k = torch.exp(k)
        kv = k * v

        # Per-channel log-weights: decayed curve for past positions, time_first last.
        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mix (feed-forward) layer.

    The input is blended per channel with the previous token, expanded to
    ``4 * n_embd`` through ``key``, passed through squared ReLU, projected back
    with ``value`` and gated by ``sigmoid(receptance(x))``.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # One-step shift along the second-to-last dimension (zero row in front).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        n_embd = config.n_embd
        with torch.no_grad():  # init to "shift half of the channels"
            mix_init = torch.ones(1, 1, n_embd)
            mix_init[0, 0, :n_embd // 2] = 0
        self.time_mix = nn.Parameter(mix_init)

        expanded = 4 * n_embd
        self.key = nn.Linear(n_embd, expanded, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(expanded, n_embd, bias=False)

        # Read by RWKV_Init via the `scale_init` attribute.
        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        previous = self.time_shift(x)
        mixed = x * self.time_mix + previous * (1 - self.time_mix)

        hidden = torch.square(torch.relu(self.key(mixed)))
        projected = self.value(hidden)

        gate = torch.sigmoid(self.receptance(mixed))
        return gate * projected

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    """Plain attribute bag for model hyper-parameters.

    ``vocab_size`` and ``ctx_len`` are mandatory; every extra keyword argument
    is stored as an attribute of the same name.
    """

    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size, self.ctx_len = vocab_size, ctx_len
        for key in kwargs:
            setattr(self, key, kwargs[key])


class Block(nn.Module):
    """One RWKV layer: LayerNorm -> mixer (+residual) -> LayerNorm -> channel-mix (+residual).

    The mixer is ``RWKV_TimeMix``, except for layer 0 of an ``'RWKV-ffnPre'``
    model where an extra ``RWKV_ChannelMix`` (``ffnPre``) is used instead.
    Note that each residual is added to the *normalised* activations.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self._uses_ffn_pre():
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def _uses_ffn_pre(self):
        # model_type is only consulted for layer 0 (short-circuit evaluation).
        return self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre'

    def forward(self, x):
        normed = self.ln1(x)
        if self._uses_ffn_pre():
            mixed = normed + self.ffnPre(normed)  # better in some cases
        else:
            mixed = normed + self.att(normed)
        normed = self.ln2(mixed)
        return normed + self.ffn(normed)


class GPT(nn.Module):
    """RWKV language model: token embedding, a stack of ``Block``s, final
    LayerNorm, a vocabulary head, plus a small q/k "copy" head whose scores are
    added to the logits of tokens that already occurred in the input.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config`` (needs ``vocab_size``,
        ``n_embd``, ``n_layer``, ``ctx_len``) and run ``RWKV_Init`` on them."""
        super().__init__()
        self.step = 0  # incremented on every forward() call
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                    for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Copy head: scale_init is consumed by RWKV_Init (0 => zero matrix).
        self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_q.scale_init = 0
        self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_k.scale_init = 0.1
        # Lower-triangular (ctx_len x ctx_len) mask of ones, used in forward()
        # to zero out scores for later positions.
        self.register_buffer("copy_mask", torch.tril(
            torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        """Return the maximum sequence length accepted by forward()."""
        return self.ctx_len

    def _init_weights(self, module):
        """Simple normal-init helper. It is not called anywhere in this class;
        __init__ uses RWKV_Init instead."""
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        """Return an Adam optimizer over all parameters with weight decay 0.

        ``train_config`` must provide ``learning_rate``, ``betas`` and ``eps``.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()  # stays empty: every parameter goes to no_decay below
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        # Sanity checks: the two sets are disjoint and together cover every parameter.
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        """Compute logits for token ids ``idx`` of size ``(B, T)``.

        Returns ``(logits, loss)``; ``loss`` is the cross-entropy against
        ``targets`` when given, otherwise ``None``.  Asserts ``T <= ctx_len``.
        """
        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        x = self.emb(idx)

        x = self.blocks(x)

        x = self.ln_out(x)

        # Copy head: (T x T) scores between positions, scaled by 1/RWKV_HEAD_QK_DIM
        # and masked so that position t only sees positions <= t.
        q = self.head_q(x)[:, :T, :]
        k = self.head_k(x)[:, :T, :]
        c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
        c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

        # Scatter each score onto the vocabulary entry of the token it refers to,
        # then add the result to the regular head logits.
        c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
        x = self.head(x) + c

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss


================================================
FILE: RWKV-v2-RNN/src/model_run.py
================================================
import types
import copy
import torch
from torch.nn import functional as F

RWKV_K_CLAMP = 60       # upper clamp applied to k before exp() in SA()
RWKV_K_EPS = 1e-16      # added to the denominator in SA()
RWKV_HEAD_QK_DIM = 256  # divisor for the head_q / head_k scores in run()

DEBUG_TIME = False   # True False - show trained time-coeffs


class RWKV_RNN():
    """Token-by-token (recurrent) inference for a trained RWKV checkpoint.

    The checkpoint's flat state dict is unpacked into a tree of
    ``SimpleNamespace`` objects (``self.w``); numeric path components such as
    ``blocks.3`` become integer keys of a dict.  Recurrent state lives in
    ``self.xx`` (previous inputs), ``self.aa`` / ``self.bb`` (running
    numerator / denominator of the time-mix) and ``self.hk`` (head_k history).
    """

    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                # stored as log(decay_speed); convert to the per-step multiplier
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # Walk / create the namespace path for key "a.b.0.c".
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Reset all recurrent state."""
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.hk = None

    def save(self, target):
        """Deep-copy the recurrent state onto attributes of ``target``."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        """Restore the recurrent state from a ``target`` filled by ``save``."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        """Layer-norm ``xx`` over ``n_embd`` using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """Channel-mix step for one token; ``name`` keys the stored previous input."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.square(torch.relu(w.key.weight @ x))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """Time-mix step for one token; updates the running sums for ``name``."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)

        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
        v = w.value.weight @ x
        kv = k * v

        # Current token is weighted by time_first; history decays by time_decay.
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Advance the model by one token and return next-token logits.

        Only ``ctx[-1]`` is fed through the network; the earlier entries are
        used solely to attribute the copy-head scores to token ids.  Each call
        is expected to extend the previous context by one token.

        Returns:
            list[float]: one logit per vocabulary entry.
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            x = self.LN(x, w.blocks[i].ln1)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
            x = self.LN(x, w.blocks[i].ln2)
            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        # Append this step's head_k vector to the history (at most ctx_len rows).
        # `is None` instead of `== None`: hk is a tensor once set.
        if self.hk is None:
            self.hk = (w.head_k.weight @ x).unsqueeze(0)
        else:
            self.hk = torch.cat(
                [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
        if self.hk.shape[0] > self.ctx_len:
            self.hk = self.hk[-self.ctx_len:, :]

        q = w.head_q.weight @ x

        x = w.head.weight @ x
        x = x.cpu().numpy().tolist()

        # Convert the scores to plain floats so the returned list never
        # contains 0-d (possibly CUDA) tensors.
        c = ((self.hk @ q) / RWKV_HEAD_QK_DIM).cpu().tolist()
        # hk holds the most recent len(c) steps, so pair the scores with the
        # *last* len(c) tokens of ctx (this matters once ctx exceeds ctx_len).
        recent = ctx[-len(c):]
        for i in range(len(c)):
            x[recent[i]] += c[i]

        return x


================================================
FILE: RWKV-v2-RNN/src/trainer.py
================================================
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

logger = logging.getLogger(__name__)
# Global backend switches, applied as a side effect of importing this module.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Opened in append mode at import time, relative to the current working
# directory; Trainer.train() writes one line per mini-epoch. It is never
# closed explicitly in this module.
log_file = open("mylog.txt", "a")


class TrainerConfig:
    """Training hyper-parameters; class attributes are the defaults and any
    keyword argument passed to the constructor overrides / adds an attribute."""

    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    # Learning rate reached at the end of the schedule. Trainer reads
    # config.lr_final whenever lr_decay is True, so a default is needed for a
    # bare TrainerConfig() to be usable.
    lr_final = 1e-5
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by cosine decay
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Trainer:
    """Minimal training loop: builds the optimizer from the model, iterates
    mini-epochs over ``train_dataset``, applies a token-based LR schedule,
    appends a line per mini-epoch to the module-level ``log_file`` and
    periodically saves the model's state dict."""

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1  # negative means "no loss recorded yet"
        self.steps = 0

        if 'wandb' in sys.modules:
            # Bind the already-imported module explicitly so this works even
            # when the `import wandb` line of this file is commented out.
            wandb = sys.modules['wandb']
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        """Return '<vocab>-<ctx_len>-<model_type>-<n_layer>-<n_embd>'."""
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        """Run ``config.max_epochs`` training mini-epochs."""
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                            progress = 0
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(
                                max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
                                                                     2) * math.cos(math.pi * progress)  # better 1.0 ~ 0.1
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate
                        # No schedule is tracked in this mode; define `progress`
                        # so the progress-bar text below does not raise NameError.
                        progress = 0

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if 'wandb' in sys.modules:
                        sys.modules['wandb'].log({"loss": now_loss},
                                                 step=self.steps * self.config.batch_size)
                    self.steps += 1

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        # running mean whose weight shrinks with the iteration index
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                torch.save(raw_model.state_dict(),
                           self.config.epoch_save_path + str(epoch+1) + '.pth')


================================================
FILE: RWKV-v2-RNN/src/utils.py
================================================
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset


class Dataset(Dataset):
    """Random-window dataset over a token sequence.

    On construction the sorted set of distinct tokens is written to
    ``vocab.json`` (UTF-16, in the current working directory) and used to build
    ``stoi`` / ``itos``.  ``__len__`` is the fixed ``epoch_length_fixed``;
    every ``__getitem__`` ignores its index and returns a random window.
    """

    def __init__(self, data, ctx_len, epoch_length_fixed):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        # index -> token; json turns the integer keys into strings
        xxObj = dict(enumerate(unique))
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``ctx_len`` token ids and the same window shifted by one.

        Raises:
            ValueError: if the data holds fewer than ``ctx_len + 1`` tokens.
        """
        # cheat: pick a random spot in dataset
        # Valid start offsets are 0 .. len(data) - (ctx_len + 1) inclusive;
        # randint's upper bound is exclusive, hence "- ctx_len".
        n_starts = len(self.data) - self.ctx_len
        if n_starts <= 0:
            raise ValueError(
                'data is too short: need at least ctx_len + 1 = %d tokens, got %d'
                % (self.ctx_len + 1, len(self.data)))
        i = np.random.randint(0, n_starts)
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        # Same as before on CUDA machines; falls back to CPU when no GPU exists.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=device)
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=device)
        return x, y


class TOKENIZER():
    """Character-level tokenizer backed by a ``{index: token}`` JSON table,
    plus the nucleus (top-p) sampler used at generation time."""

    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        """Load ``WORD_NAME + '.json'`` (UTF-16) and build ``stoi`` / ``itos``.

        ``UNKNOWN_CHAR`` must be present in the table; its id is stored in
        ``self.UNKNOWN_CHAR`` (KeyError otherwise).
        """
        with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
            self.word_table = json.load(result_file)

        self.vocab_size = len(self.word_table)

        self.stoi = {v: int(k) for k, v in self.word_table.items()}
        self.itos = {int(k): v for k, v in self.word_table.items()}

        self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        """Normalise a prompt: trim every line, drop empty lines and prefix a
        single newline.  An empty prompt therefore becomes ``'\\n'``."""
        lines = context.strip().split('\n')
        lines = [ln.strip().strip('\u3000').strip('\r') for ln in lines]
        lines = [ln for ln in lines if ln != '']
        # The result always starts with '\n', so no separate empty-string
        # fallback is needed.
        return '\n' + ('\n'.join(lines)).strip()

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        """Sample one token id from logits ``out``.

        Args:
            out: logits, one per vocabulary entry.
            x: token ids so far; only ``x[-1]`` is inspected.
            ctx_len: unused, kept for interface compatibility.
            temperature: probabilities are raised to ``1 / temperature``.
            top_p_usual: nucleus threshold; ``None`` disables the cut-off.
            top_p_newline: threshold used instead when the last token is ``'\\n'``.

        Returns:
            0-d tensor holding the sampled token id.
        """
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        if top_p is not None:
            sorted_probs, s_index = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
            over = cumulative_probs > top_p
            # When the cumulative mass never exceeds top_p (e.g. top_p >= 1)
            # nothing is cut; previously argmax() of an all-False mask returned
            # 0 and silently collapsed sampling to the top token(s).
            if over.any():
                cutoff = float(sorted_probs[np.argmax(over)])
                probs[probs < cutoff] = 0

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        # multinomial accepts unnormalised non-negative weights
        return torch.multinomial(probs, num_samples=1)[0]


def to_float(x):
    """Return the first element of tensor ``x`` as a float scalar."""
    values = x.detach().cpu().numpy()
    first = values.reshape(-1)[0]
    return first.astype(float)


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


================================================
FILE: RWKV-v2-RNN/train.py
================================================
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'

### Step 2: set model size #############################################################################

ctx_len = 1024        # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512

# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'

### Step 3: set batch size #############################################################################

# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 12

### Step 4: set learning rate, training mini-epochs #######################################################

lr_init = 6e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 500
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'

epoch_length_fixed = 10000

########################################################################################################

# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

grad_norm_clip = 1.0
warmup_tokens = 0

betas = (0.9, 0.99)
eps = 4e-9

num_workers = 0

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)
# Read the whole file through a context manager so the handle is closed
# as soon as the text is in memory.
with open(datafile, "r", encoding=datafile_encoding) as data_fp:
    train_dataset = Dataset(data_fp.read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':

    model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                          n_layer=n_layer, n_embd=n_embd)).cuda()

    # # # load a trained model. remember to change random seed
    # m2 = torch.load('trained-61.pth')
    # model.load_state_dict(m2)

    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
    # final_tokens = total number of tokens seen over the whole run; it drives
    # the learning-rate schedule in Trainer.
    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    trainer = Trainer(model, train_dataset, None, tconf)

    trainer.train()

    # Final checkpoint, named after the run configuration and a timestamp.
    torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
               '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')


================================================
FILE: RWKV-v3/cuda/timex_cuda.cu
================================================
#include <stdio.h>

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax, BF and BB are passed in by the compiler as -D macros)

// F4(A, B): view array A as an array of float4 and select the float4 that contains
// element index B (B >> 2), so a single access moves four consecutive values.
// Every use in this file passes an index built from t = threadIdx.x << 2; the index is
// presumably always a multiple of 4 (relies on T % 4 == 0 and Tmax % 4 == 0).
#define F4(A, B) ((float4 *)(A))[(B) >> 2]

// Forward kernel. For every row r of the flattened (B * C) x T input it produces
//     x[r][t] = eps + sum_{u = 0..t} w[c][T - 1 - (t - u)] * k[r][u]
// i.e. a causal weighted sum along time in which w[T - 1] multiplies the current step
// and lower w indices multiply progressively older steps.
//
// Work split (see cuda_forward): blockIdx.y selects a base row i, each block handles the
// BF rows i + ij * j (j = 0..BF-1), and each thread produces 4 consecutive time steps
// starting at t = 4 * threadIdx.x.
// The weight row is selected with i % C for all BF rows of the block; this presumably
// relies on the stated requirement B % BF == 0, which makes ij a multiple of C.
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
                               const F eps, const int B, const int C, const int T) {
    const int i = blockIdx.y;            // base row handled by this block
    const int ij = (B * C) / BF;         // row stride between the BF rows of one block
    const int t = threadIdx.x << 2;      // first of the 4 time steps owned by this thread

    // Stage the weight row and the BF key rows in shared memory (4 values per thread per row).
    __shared__ F ww[Tmax];
    __shared__ F kk[Tmax * BF];
    F4(ww, t) = F4(__w, t + T * (i % C));
    
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
    }
    __syncthreads();  // all loads must land before any thread reads other threads' elements

    // One float4 accumulator per row; components x/y/z/w hold outputs t, t+1, t+2, t+3.
    float4 s[BF];
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        s[j] = {eps, eps, eps, eps};
    }
    // Shifted view so that w[u + 3] is the weight pairing k[u] with output t,
    // w[u + 2] with output t+1, and so on.
    const F *__restrict__ const w = ww + T - t - 4;
    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BF; j++) {
            const F x = kk[u + Tmax * j];  // note: this local shadows the output pointer x inside the loop
            s[j].x += w[u + 3] * x;
            s[j].y += w[u + 2] * x;
            s[j].z += w[u + 1] * x;
            s[j].w += w[u + 0] * x;
        }
    }
    // The loop above only covered u <= t. Outputs t+1..t+3 still need the terms for
    // k[t+1..t+3]; add them here, then store the four results with one float4 write.
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        s[j].y += w[t + 3] * k[t + 1];
        s[j].z += w[t + 2] * k[t + 1];
        s[j].z += w[t + 3] * k[t + 2];
        s[j].w += w[t + 1] * k[t + 1];
        s[j].w += w[t + 2] * k[t + 2];
        s[j].w += w[t + 3] * k[t + 3];
        F4(x, t + T * (i + ij * j)) = s[j];
    }
}

// Weight-gradient-only kernel. For the row i = blockIdx.y it computes
//     gw[i][t] = sum_{u = 0..t} gwk[i][T - 1 - (t - u)] * k[i][u]
// with each thread producing 4 consecutive entries starting at t = 4 * threadIdx.x.
// The parameters __w, gk, B and C are not used in the body, and no host launcher for
// this kernel appears in this file (cuda_backward launches kernel_backward instead).
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;        // row handled by this block
    const int t = threadIdx.x << 2;  // first of the 4 entries owned by this thread

    // Stage this row of k and of the incoming gradient in shared memory.
    __shared__ F k[Tmax];
    __shared__ F gg[Tmax];
    F4(k, t) = F4(__k, t + T * i);
    F4(gg, t) = F4(__gwk, t + T * i);
    __syncthreads();

    float4 s = {0, 0, 0, 0};  // components x/y/z/w hold gw[t], gw[t+1], gw[t+2], gw[t+3]

    // Shifted view so that g[u + 3] pairs k[u] with entry t, g[u + 2] with entry t+1, ...
    const F *__restrict__ const g = gg + T - t - 4;
    for (int u = 0; u <= t; u++) {
        F x = k[u];
        s.x += g[u + 3] * x;
        s.y += g[u + 2] * x;
        s.z += g[u + 1] * x;
        s.w += g[u + 0] * x;
    }
    // Remaining terms for entries t+1..t+3 (u = t+1..t+3), not covered by the loop.
    s.y += g[t + 3] * k[t + 1];
    s.z += g[t + 2] * k[t + 1];
    s.z += g[t + 3] * k[t + 2];
    s.w += g[t + 1] * k[t + 1];
    s.w += g[t + 2] * k[t + 2];
    s.w += g[t + 3] * k[t + 3];
    F4(gw, t + T * i) = s;
}
// Host launcher for kernel_forward: one block (along grid y) per group of BF rows,
// and one thread per 4 time steps.
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    const int rowGroups = B * C / BF;
    const int threadsPerBlock = T >> 2;
    const dim3 grid(1, rowGroups);
    const dim3 block(threadsPerBlock);
    kernel_forward<<<grid, block>>>(w, k, x, eps, B, C, T);
}

// Backward kernel. For every row r of the flattened (B * C) x T tensors it produces
//     gw[r][t] = sum_{u = 0..t}   gwk[r][T - 1 - (t - u)] * k[r][u]
//     gk[r][t] = sum_{u = t..T-1} gwk[r][T - 1 - (u - t)] * w[c][u]
// gw is written per row (one weight gradient per (batch, channel) row); any reduction
// over the batch is left to the caller.
//
// Work split (see cuda_backward): blockIdx.y selects a base row i, each block handles the
// BB rows i + ij * j (j = 0..BB-1), and each thread produces 4 consecutive time steps
// starting at t = 4 * threadIdx.x.
// The weight row is selected with i % C for all BB rows of the block; this presumably
// relies on the stated requirement B % BB == 0, which makes ij a multiple of C.
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;        // base row handled by this block
    const int ij = (B * C) / BB;     // row stride between the BB rows of one block
    const int t = threadIdx.x << 2;  // first of the 4 time steps owned by this thread

    // Stage the weight row plus BB rows of k and of the incoming gradient in shared memory.
    __shared__ F w[Tmax];
    __shared__ F kk[Tmax * BB];
    __shared__ F gg[Tmax * BB];
    F4(w, t) = F4(__w, t + T * (i % C));

    #pragma unroll
    for (int j = 0; j < BB; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
        F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
    }
    __syncthreads();  // all loads must land before any thread reads other threads' elements

    // ---- Part 1: gw. Components x/y/z/w of s[j] hold entries t, t+1, t+2, t+3. ----
    float4 s[BB];
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            // Shifted view so that g[u + 3] pairs k[u] with entry t, g[u + 2] with entry t+1, ...
            const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
            F x = kk[u + Tmax * j];
            s[j].x += g[u + 3] * x;
            s[j].y += g[u + 2] * x;
            s[j].z += g[u + 1] * x;
            s[j].w += g[u + 0] * x;
        }
    }
    // Remaining terms for entries t+1..t+3 (u = t+1..t+3), then store gw.
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
        s[j].y += g[t + 3] * k[t + 1];
        s[j].z += g[t + 2] * k[t + 1];
        s[j].z += g[t + 3] * k[t + 2];
        s[j].w += g[t + 1] * k[t + 1];
        s[j].w += g[t + 2] * k[t + 2];
        s[j].w += g[t + 3] * k[t + 3];
        F4(gw, t + T * (i + ij * j)) = s[j];
    }

    // ---- Part 2: gk. The accumulators are reused after being reset to zero. ----
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    // Terms with u >= t + 3 contribute to all four outputs t..t+3.
    for (int u = t + 3; u < T; u++) {
        F x = w[u];
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            // Shifted view so that g[2 - u] is gwk[T - 1 - (u - t)], the gradient paired with w[u] for output t.
            const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
            s[j].x += g[2 - u] * x;
            s[j].y += g[3 - u] * x;
            s[j].z += g[4 - u] * x;
            s[j].w += g[5 - u] * x;
        }        
    }
    // Terms with u in t..t+2 only apply to the outputs at or before u; add them, then store gk.
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
        s[j].x += g[2 - t] * w[t + 0];
        s[j].x += g[1 - t] * w[t + 1];
        s[j].x += g[0 - t] * w[t + 2];
        s[j].y += g[2 - t] * w[t + 1];
        s[j].y += g[1 - t] * w[t + 2];
        s[j].z += g[2 - t] * w[t + 2];
        F4(gk, t + T * (i + ij * j)) = s[j];
    }
}
// Host launcher for kernel_backward: one block (along grid y) per group of BB rows,
// and one thread per 4 time steps.
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    const int rowGroups = B * C / BB;
    const int threadsPerBlock = T >> 2;
    const dim3 grid(1, rowGroups);
    const dim3 block(threadsPerBlock);
    kernel_backward<<<grid, block>>>(w, k, gwk, gw, gk, B, C, T);
}


================================================
FILE: RWKV-v3/cuda/timex_op.cpp
================================================
#include <torch/extension.h>

// Launchers called by the wrappers below; they take raw float pointers plus the
// B, C, T sizes. Their definitions are not in this translation unit (presumably the
// companion .cu file -- confirm against the build's source list).
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);

// Python-facing wrapper around cuda_forward.
//   w, k : inputs; x : output buffer written by the kernel; eps : initial accumulator value;
//   B, C, T : sizes forwarded to the launcher.
// The launcher reinterprets each tensor's storage as dense float32 device memory, so
// anything else is rejected up front instead of being read as garbage or faulting on the GPU.
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
    TORCH_CHECK(w.is_cuda() && k.is_cuda() && x.is_cuda(),
                "timex forward: w, k and x must be CUDA tensors");
    TORCH_CHECK(w.scalar_type() == torch::kFloat32 && k.scalar_type() == torch::kFloat32 && x.scalar_type() == torch::kFloat32,
                "timex forward: w, k and x must be float32 tensors");
    TORCH_CHECK(w.is_contiguous() && k.is_contiguous() && x.is_contiguous(),
                "timex forward: w, k and x must be contiguous");
    cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
// Python-facing wrapper around cuda_backward.
//   w, k : the forward inputs; gwk : gradient of the forward output;
//   gw, gk : output buffers written by the kernel; B, C, T : sizes forwarded to the launcher.
// The launcher reinterprets each tensor's storage as dense float32 device memory, so
// anything else is rejected up front instead of being read as garbage or faulting on the GPU.
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
    TORCH_CHECK(w.is_cuda() && k.is_cuda() && gwk.is_cuda() && gw.is_cuda() && gk.is_cuda(),
                "timex backward: w, k, gwk, gw and gk must be CUDA tensors");
    TORCH_CHECK(w.scalar_type() == torch::kFloat32 && k.scalar_type() == torch::kFloat32 &&
                gwk.scalar_type() == torch::kFloat32 && gw.scalar_type() == torch::kFloat32 &&
                gk.scalar_type() == torch::kFloat32,
                "timex backward: w, k, gwk, gw and gk must be float32 tensors");
    TORCH_CHECK(w.is_contiguous() && k.is_contiguous() && gwk.is_contiguous() && gw.is_contiguous() && gk.is_contiguous(),
                "timex backward: w, k, gwk, gw and gk must be contiguous");
    cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}

// pybind11 bindings: expose the two wrappers as `forward` / `backward` functions of the
// extension module whose name is supplied by the build through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "timex forward");
    m.def("backward", &backward, "timex backward");
}

// Also register the same two wrappers as operators in the `timex` library namespace.
TORCH_LIBRARY(timex, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}


================================================
FILE: RWKV-v3/run.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
# Global backend / printing settings for this script.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

# Model hyper-parameters passed to RWKV_RNN below; presumably they must match the
# values used when the checkpoint was trained.
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'           # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-1'      # checkpoint name handed to RWKV_RNN
WORD_NAME = 'vocab'           # the .json vocab (generated by train.py)

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '   # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'   # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output (True also limits the run to 1 trial of 1 token)

### Step 2: set context ################################################################################

context = "\nIn the"       # ==> this is your prompt

NUM_TRIALS = 999          # number of independent generations from the same prompt
LENGTH_PER_TRIAL = 500    # tokens generated per trial

# Sampling settings forwarded to tokenizer.sample_logits(...)
TEMPERATURE = 1.0
top_p = 0.7               # passed as top_p_usual
top_p_newline = 0.9       # passed as top_p_newline

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')

# In debug mode only one trial producing a single token is run.
n_trials = 1 if DEBUG_DEBUG else NUM_TRIALS
n_new_tokens = 1 if DEBUG_DEBUG else LENGTH_PER_TRIAL

for TRIAL in range(n_trials):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(ch, tokenizer.UNKNOWN_CHAR) for ch in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL > 0:
        # Later trials restart from the state cached after the prompt.
        model.load(init_state)
    else:
        # First trial: feed the prompt one prefix at a time, keep the output of the
        # final prefix, then snapshot the model state for reuse.
        init_state = types.SimpleNamespace()
        for prefix_len in range(1, src_len + 1):
            prefix_out = model.run(ctx[:prefix_len])
            if prefix_len == src_len:
                init_state.out = prefix_out
        model.save(init_state)

    for i in range(src_len, src_len + n_new_tokens):
        # Everything generated so far, cropped to the most recent ctx_len tokens.
        x = ctx[:i+1][-ctx_len:]

        # The first new token reuses the cached prompt output instead of re-running the model.
        out = copy.deepcopy(init_state.out) if i == src_len else model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        sampled = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                          top_p_usual=top_p, top_p_newline=top_p_newline)
        char = sampled.item()
        print(tokenizer.itos[int(char)], end='', flush=True)
        ctx.append(char)
    t_end = time.time_ns()
    elapsed_s = round((t_end - t_begin) / (10 ** 9), 2)
    print("\n----------", elapsed_s, end='s ')


================================================
FILE: RWKV-v3/src/model.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

# Upper bound applied to k (torch.clamp) before exp() in RWKV_TimeMix.forward.
RWKV_K_CLAMP = 60  # e^60 = 1e26
# Passed as `eps` to the TimeX call that computes the denominator wk in RWKV_TimeMix.forward.
RWKV_K_EPS = 1e-8
# Width of the head_q / head_k projections in GPT; the extra head path is only built when > 0.
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

########################################################################################################
# CUDA Kernel
########################################################################################################

# These three values are handed to the CUDA compiler as -DTmax / -DBF / -DBB below and
# are re-checked by the assertions in TimeX.
T_MAX = 1024          # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4   # set to 8 for best performance
B_GROUP_BACKWARD = 2  # set to 2 for best performance (sometimes 8 is faster)

# Build and load the extension when this module is imported; the source paths are
# relative, so they are resolved against the current working directory.
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])


class TimeX(torch.autograd.Function):
    """Autograd bridge to the compiled `timex_cuda` forward / backward functions.

    forward(w, k, B, C, T, eps) returns a new (B, C, T) CUDA tensor filled by
    `timex_cuda.forward`; backward returns gradients for `w` (summed over the
    batch dimension) and `k`, and None for the four non-tensor arguments.
    """

    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        # Remember the sizes; backward needs them to allocate its outputs.
        ctx.B, ctx.C, ctx.T = B, C, T
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w, k = w.contiguous(), k.contiguous()
        ctx.save_for_backward(w, k)
        out = torch.empty((B, C, T), device='cuda',
                          memory_format=torch.contiguous_format)
        timex_cuda.forward(w, k, out, eps, B, C, T)
        return out

    @staticmethod
    def backward(ctx, gwk):
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w, k = ctx.saved_tensors
        out_shape = (ctx.B, ctx.C, ctx.T)
        grad_w_per_row = torch.empty(out_shape, device='cuda',
                                     memory_format=torch.contiguous_format)
        grad_k = torch.empty(out_shape, device='cuda',
                             memory_format=torch.contiguous_format)
        timex_cuda.backward(w, k, gwk.contiguous(), grad_w_per_row,
                            grad_k, ctx.B, ctx.C, ctx.T)
        # w is shared across the batch, so its per-row gradients are summed over dim 0.
        return (grad_w_per_row.sum(dim=0), grad_k, None, None, None, None)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(module, config):  # fancy initialization of all lin & emb layer in the module
    """Re-initialise the weight of every nn.Linear / nn.Embedding found in `module`.

    Args:
        module: root nn.Module; all of its sub-modules are visited.
        config: object providing `vocab_size` and `n_embd`, used to recognise the
            token embedding and the final projection by their weight shape.

    For each layer a `gain` and a `scale` are chosen (a `scale_init` attribute on the
    layer overrides the scale), and the weight is then set by the sign of gain * scale:
    identity when scale == -999, zeros when the product is 0, orthogonal when it is
    positive, and normal(0, -scale) when it is negative. Linear biases are zeroed.
    """
    # Build the id -> name lookup once instead of rescanning every parameter for each
    # layer. It also lets the '[unknown weight]' fallback survive when no name matches.
    param_names = {id(p): n for n, p in module.named_parameters()}

    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            # Only consumed by the optional debug print below.
            name = param_names.get(id(m.weight), '[unknown weight]')

            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    """RWKV time-mixing layer (training version, backed by the TimeX CUDA op).

    Args:
        config: object providing `ctx_len`, `n_embd` and `n_layer`.
        layer_id: index of the owning block; it only shapes the initial values of
            the time_* parameters.

    forward(x) takes x of shape (B, T, C) and returns a tensor of the same shape.
    T must satisfy the TimeX assertions (T % 4 == 0, T <= T_MAX, B divisible by
    the forward / backward group sizes).
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad(): # fancy init
            # Integer offsets -(ctx_len-2), ..., -1, 0; multiplied by e^time_decay in forward().
            # Kept as a plain attribute (not a buffer) so the state_dict layout is unchanged.
            self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
            self.time_curve = self.time_curve.to('cuda')

            # max(..., 1) keeps a single-layer model from dividing by zero; for
            # n_layer >= 2 the value is unchanged.
            ratio_0_to_1 = (layer_id / max(config.n_layer - 1, 1)) # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
            
            # fancy time_decay
            decay_speed = torch.ones(attn_sz, 1)
            for h in range(attn_sz):
                # Same guard for a single-channel model (attn_sz == 1).
                decay_speed[h][0] = -5 + 8 * (h / max(attn_sz - 1, 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1)
            self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag)
            
            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
            

        # Shifts the sequence one step along time: row t receives row t-1, row 0 receives zeros.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # Read by RWKV_Init: these three projections start as all-zero matrices.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size() # x = (Batch,Time,Channel)

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk).transpose(-1, -2)
        v = self.value(xv).transpose(-1, -2)
        r = self.receptance(xr)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow
        k = torch.exp(k)
        kv = k * v

        # Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        # The kernel indexes w as a dense (C, T) table whose last column belongs to the
        # current step, so keep only the last T taps. This is a no-op when T == ctx_len
        # and makes shorter sequences read the right weights; TimeX makes the slice contiguous.
        w = torch.exp(self.time_w)[:, -T:]

        # Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mixing layer: a gated feed-forward block over time-shifted inputs.

    Args:
        config: object providing `n_embd` and `n_layer`.
        layer_id: index used only to shape the initial time_mix values.

    forward(x) returns sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2), where
    xk / xr are per-channel blends of x and x shifted one step along dim -2.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # Shifts the input one step along dim -2 (row t receives row t-1, row 0 zeros).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        n_embd = config.n_embd
        with torch.no_grad(): # fancy init of time_mix
            exponent = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            # Per-channel ramp 0, 1/n, 2/n, ... shaped (1, 1, n_embd).
            ramp = torch.tensor([i / n_embd for i in range(n_embd)]).view(1, 1, n_embd)

            self.time_mix_k = nn.Parameter(torch.pow(ramp, exponent))
            self.time_mix_r = nn.Parameter(torch.pow(ramp, exponent))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

        # Read by RWKV_Init: both projections start as all-zero matrices.
        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        shifted = self.time_shift(x)
        mix_k = self.time_mix_k
        mix_r = self.time_mix_r
        key_in = x * mix_k + shifted * (1 - mix_k)
        gate_in = x * mix_r + shifted * (1 - mix_r)

        hidden = torch.square(torch.relu(self.key(key_in)))
        projected = self.value(hidden)

        return torch.sigmoid(self.receptance(gate_in)) * projected

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    """Plain attribute bag for model settings.

    `vocab_size` and `ctx_len` are required; every extra keyword argument
    (e.g. model_type, n_layer, n_embd) becomes an attribute of the same name.
    """

    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for option in kwargs:
            setattr(self, option, kwargs[option])


class Block(nn.Module):
    """One residual block: (time-mix or ffnPre) followed by channel-mix.

    Layer 0 additionally applies `ln0` to its input, and uses a channel-mix
    `ffnPre` in place of the time-mix when config.model_type is 'RWKV-ffnPre'.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

        first_layer = self.layer_id == 0
        if first_layer:
            self.ln0 = nn.LayerNorm(width)

        if first_layer and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        first_layer = self.layer_id == 0
        if first_layer:
            x = self.ln0(x)

        normed = self.ln1(x)
        if first_layer and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(normed)  # better in some cases
        else:
            x = x + self.att(normed)

        return x + self.ffn(self.ln2(x))


class GPT(nn.Module):
    """RWKV language model: embedding -> stack of Blocks -> LayerNorm -> vocab head.

    When RWKV_HEAD_QK_DIM > 0 an extra head_q / head_k path adds, to each position's
    logits, causally masked q.k scores scattered onto the tokens already seen.
    """

    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        layers = [Block(config, layer_id) for layer_id in range(config.n_layer)]
        self.blocks = nn.Sequential(*layers)

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            # Lower-triangular mask: position t may only look at positions <= t.
            lower_triangle = torch.tril(torch.ones(config.ctx_len, config.ctx_len))
            self.register_buffer("copy_mask", lower_triangle)

        self.ctx_len = config.ctx_len

        RWKV_Init(self, config)

        n_params = sum(p.numel() for p in self.parameters())
        logger.info("number of parameters: %e", n_params)

    def get_ctx_len(self):
        """Return the maximum sequence length accepted by forward()."""
        return self.ctx_len

    def _init_weights(self, module):
        """Simple normal initialisation for a single Linear / Embedding module."""
        is_linear = isinstance(module, (nn.Linear))
        if is_linear:
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        """Build an Adam optimizer over all parameters with weight decay disabled.

        Reads learning_rate, betas and eps from `train_config`.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        # here we disable weight_decay: every parameter lands in no_decay
        for module_name, submodule in self.named_modules():
            for param_name, _ in submodule.named_parameters():
                # full param name
                full_name = '%s.%s' % (module_name, param_name) if module_name else param_name
                no_decay.add(full_name)

        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        no_decay_params = [param_dict[name] for name in sorted(list(no_decay))]
        optim_groups = [
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        return torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None unless `targets` is given.

        idx is a (B, T) tensor of token ids with T <= ctx_len.
        """
        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.ln_out(self.blocks(self.emb(idx)))

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            scores = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            scores = scores.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # Scatter each (t, t') score onto the vocabulary entry of token idx[t'].
            copied = scores @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
            x = self.head(x) + copied
        else:
            x = self.head(x)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss


================================================
FILE: RWKV-v3/src/model_run.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn

# Upper bound applied to k (torch.clamp) before exp() in RWKV_TimeMix.forward.
RWKV_K_CLAMP = 60
# Added to the denominator wk in RWKV_TimeMix.forward.
RWKV_K_EPS = 1e-8
# Width of the head_q / head_k projections in RWKV_GPT; the extra head path is only built when > 0.
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False   # True False - show trained time-coeffs

############################################################################################################

# Module-wide settings holder. RWKV_GPT.__init__ fills it in (n_embd, ctx_len, model_type, ...)
# before the layer classes below read it in their constructors.
RWKV_CFG = types.SimpleNamespace()

class RWKV_ChannelMix(nn.Module):
    """Channel-mixing layer for inference; sizes come from the global RWKV_CFG.

    Parameters are created with placeholder values (ones / default Linear init).
    forward(x) returns sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2).
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        n_embd = RWKV_CFG.n_embd

        # Shifts the input one step along dim -2 (row t receives row t-1, row 0 zeros).
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        shifted = self.time_shift(x)
        mix_k = self.time_mix_k
        mix_r = self.time_mix_r
        key_in = x * mix_k + shifted * (1 - mix_k)
        gate_in = x * mix_r + shifted * (1 - mix_r)

        hidden = torch.square(torch.relu(self.key(key_in)))
        projected = self.value(hidden)

        return torch.sigmoid(self.receptance(gate_in)) * projected

class RWKV_TimeMix(nn.Module):
    """Time-mixing layer for inference, implemented with a grouped conv1d.

    Sizes come from the global RWKV_CFG; parameters are created with placeholder
    values. forward(x) takes (B, T, C) and returns a tensor of the same shape.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        n_embd = RWKV_CFG.n_embd
        ctx_len = RWKV_CFG.ctx_len

        self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
        # Integer offsets -(ctx_len-2), ..., -1, 0 as a (1, ctx_len-1) tensor (plain attribute, not a buffer).
        self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
        
        # Shifts the input one step along dim -2 (row t receives row t-1, row 0 zeros).
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,n_embd))

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)

        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        shifted = self.time_shift(x)
        key_in = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        value_in = x * self.time_mix_v + shifted * (1 - self.time_mix_v)
        gate_in = x * self.time_mix_r + shifted * (1 - self.time_mix_r)

        # k and v are moved to (B, C, T) for the convolution along time.
        k = self.key(key_in).transpose(-1, -2)
        v = self.value(value_in).transpose(-1, -2)
        r = self.receptance(gate_in)

        # Clamp before exp() to avoid overflow.
        k = torch.exp(torch.clamp(k, max=RWKV_K_CLAMP))

        kv = k * v

        curve = self.time_curve.to(self.time_decay.device)
        self.time_w = torch.cat([torch.exp(self.time_decay) * curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)
        
        # Keep the last T taps and shape them as one single-channel filter per channel.
        w = w[:,-T:].unsqueeze(1)
        # Left-pad by T-1 so every output step only sees current and earlier inputs.
        causal_pad = nn.ZeroPad2d((T-1, 0, 0, 0))
        wkv = F.conv1d(causal_pad(kv), w, groups=C)
        wk = F.conv1d(causal_pad(k), w, groups=C) + RWKV_K_EPS

        gated = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        
        return self.output(gated)

class Block(nn.Module):
    """One residual block for inference: (time-mix or ffnPre) followed by channel-mix.

    Layer 0 additionally applies `ln0` to its input, and uses a channel-mix
    `ffnPre` in place of the time-mix when RWKV_CFG.model_type is 'RWKV-ffnPre'.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        width = RWKV_CFG.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

        first_layer = self.layer_id == 0
        if first_layer:
            self.ln0 = nn.LayerNorm(width)

        if first_layer and RWKV_CFG.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        first_layer = self.layer_id == 0
        if first_layer:
            x = self.ln0(x)

        normed = self.ln1(x)
        if first_layer and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(normed)
        else:
            x = x + self.att(normed)

        return x + self.ffn(self.ln2(x))

class RWKV_GPT(nn.Module):
    """Parallel (GPT-style) RWKV model for inference, loaded from `<MODEL_NAME>.pth`.

    Args:
        MODEL_NAME: checkpoint path without the '.pth' extension.
        RUN_DEVICE: stored in RWKV_CFG.RUN_DEVICE; the module itself is not moved
            to that device here, so the caller is expected to do it if needed.
        model_type: 'RWKV' or 'RWKV-ffnPre' (read by Block).
        vocab_size, n_layer, n_embd, ctx_len: model dimensions.

    The constructor writes all settings into the global RWKV_CFG before the layers
    are built, loads the checkpoint and leaves the module in eval mode.
    forward(idx) takes (B, T) token ids with T <= ctx_len and returns the logits.
    """

    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            # Lower-triangular mask: position t may only look at positions <= t.
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len

        # Deserialize onto the CPU: the freshly built parameters live there, and
        # load_state_dict copies the values in. Without map_location, a checkpoint
        # saved from a GPU cannot be opened on a CPU-only machine.
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(MODEL_NAME + '.pth', map_location='cpu')
        self.load_state_dict(state_dict)
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        
        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # Scatter each (t, t') score onto the vocabulary entry of token idx[t'].
            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
            x = self.head(x) + c
        else:
            x = self.head(x)        

        return x

############################################################################################################

class RWKV_RNN():
    """Token-by-token ("RNN mode") RWKV inference on raw checkpoint tensors.

    The checkpoint is unpacked into a tree of namespaces (``self.w``) and the
    model is evaluated one token per ``run`` call.  Everything remembered
    between calls lives in four attributes:

    * ``self.xx`` -- previous input of every mixing layer, keyed by layer name
    * ``self.aa`` / ``self.bb`` -- running numerator / denominator of ``SA``
    * ``self.hk`` -- keys of the optional q/k copy head, newest row last,
      at most ``ctx_len`` rows
    """

    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        """Load ``MODEL_NAME + '.pth'`` onto ``RUN_DEVICE`` and reset the state."""
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            # Pre-process the time_* tensors once so the per-token code can use
            # them directly: drop singleton dims, and store decay / first in
            # their exponentiated form.
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # Turn the dotted key (e.g. "blocks.0.att.key.weight") into nested
            # attribute access.  A numeric path component becomes an int key of
            # a dict held by its parent; every other component is an attribute
            # of a SimpleNamespace.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Forget everything seen so far."""
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.hk = None

    def save(self, target):
        """Deep-copy the recurrent state onto ``target`` (any attribute-bearing object)."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        """Restore the recurrent state from an object previously filled by ``save``."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        """Layer-normalise ``xx`` over its last ``n_embd`` entries using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """Apply mixing layer ``w`` to ``xx`` and remember ``xx`` under ``name``.

        The input is blended with the previous input stored in
        ``self.xx[name]`` (zeros on first use) using ``time_mix_k`` /
        ``time_mix_r``.  Returns ``sigmoid(receptance @ xr) * (value @ relu(key @ xk)**2)``.
        """
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """Apply time-mixing layer ``w`` to ``xx`` and update its running state.

        ``self.aa[name]`` and ``self.bb[name]`` accumulate ``k * v`` and ``k``
        over past tokens, multiplied by ``time_decay`` at every step.  The
        current token enters the output weighted by ``time_first`` before it
        is folded into the running sums.
        """
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        # The clamp bounds the argument of exp().
        k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
        v = w.value.weight @ xv
        kv = k * v

        # Output uses the state *before* this token plus the current token's
        # contribution; only afterwards is the state decayed and updated.
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Feed the newest token ``ctx[-1]`` and return its logits.

        Only ``ctx[-1]`` goes through the network; earlier tokens are assumed
        to have been fed by previous calls.  When ``RWKV_HEAD_QK_DIM > 0`` the
        earlier entries of ``ctx`` are additionally used to route the
        copy-head bonus to the right vocabulary slots.

        Args:
            ctx: sequence of token ids ending with the token to process.

        Returns:
            A list of ``vocab_size`` Python floats.

        Raises:
            IndexError: if ``ctx`` holds fewer tokens than the copy-head
                history (e.g. only the newest token was passed).
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            # In the 'RWKV-ffnPre' variant layer 0 uses ffnPre in place of att.
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            new_key = (w.head_k.weight @ x).unsqueeze(0)
            if self.hk is None:
                self.hk = new_key
            else:
                self.hk = torch.cat([self.hk, new_key], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = (w.head.weight @ x).tolist()

            # Converted to plain floats so the returned list never contains
            # 0-dim tensors.
            c = ((self.hk @ q) / RWKV_HEAD_QK_DIM).tolist()

            # hk[-1] belongs to ctx[-1], hk[-2] to ctx[-2], ...  Align on the
            # tail so the pairing stays correct when hk has been trimmed to
            # ctx_len rows but the caller passes a longer, untrimmed ctx.
            offset = len(ctx) - len(c)
            if offset < 0:
                raise IndexError(
                    "ctx is shorter than the copy-head history; pass the "
                    "tokens fed so far or call clear() first")
            for i in range(len(c)):
                x[ctx[offset + i]] += c[i]
        else:
            x = (w.head.weight @ x).tolist()

        return x


================================================
FILE: RWKV-v3/src/trainer.py
================================================
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Global PyTorch backend switches, applied as a side effect of importing this
# module: turn on cuDNN benchmark mode and allow TF32 in cuDNN and in CUDA
# matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Opened in append mode at import time, relative to the current working
# directory.  Nothing in the visible part of this module closes the handle.
log_file = open("mylog.txt", "a")


class TrainerConfig:
    """Default training settings; any keyword passed to the constructor overrides
    (or adds) the attribute of the same name on that instance only."""

    # --- defaults shared by every instance unless overridden ---
    max_epochs = 10
    batch_size = 64  # handed to DataLoader as batch_size
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by cosine decay
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        # Each keyword becomes an instance attribute, shadowing the class default.
        for name in kwargs:
            setattr(self, name, kwargs[name])


class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.steps = 0

        if 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

      
Download .txt
gitextract_nqr7ihm_/

├── .github/
│   └── FUNDING.yml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── RWKV-8.md
├── RWKV-v1/
│   ├── src/
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v2-RNN/
│   ├── cuda/
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── enwik8-vocab.json
│   ├── run.py
│   ├── src/
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v3/
│   ├── cuda/
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── run.py
│   ├── src/
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4/
│   ├── 20B_tokenizer.json
│   ├── cuda/
│   │   ├── wkv_cuda.cu
│   │   └── wkv_op.cpp
│   ├── run.py
│   ├── src/
│   │   ├── binidx.py
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4neo/
│   ├── 20B_tokenizer.json
│   ├── chat.py
│   ├── cuda/
│   │   ├── wkv5_cuda.cu
│   │   ├── wkv5_op.cpp
│   │   ├── wkv_cuda.cu
│   │   ├── wkv_cuda_bf16.cu
│   │   ├── wkv_op.cpp
│   │   └── wkv_op_bf16.cpp
│   ├── img_demoAE.py
│   ├── math_demo/
│   │   └── run.py
│   ├── run.py
│   ├── src/
│   │   ├── __init__.py
│   │   ├── binidx.py
│   │   ├── dataset.py
│   │   ├── model.py
│   │   ├── model_img.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v5/
│   ├── compute_magic_prime.py
│   ├── cuda/
│   │   ├── wkv5_cuda.cu
│   │   ├── wkv5_op.cpp
│   │   ├── wkv6_cuda.cu
│   │   ├── wkv6_op.cpp
│   │   ├── wkv6state_cuda.cu
│   │   ├── wkv6state_op.cpp
│   │   ├── wkv7_cuda.cu
│   │   └── wkv7_op.cpp
│   ├── demo-training-prepare-v7-pile.sh
│   ├── demo-training-prepare.sh
│   ├── demo-training-run-v7-pile.sh
│   ├── demo-training-run.sh
│   ├── demo.jsonl
│   ├── make_data.py
│   ├── rwkv_v6_demo.py
│   ├── src/
│   │   ├── __init__.py
│   │   ├── binidx.py
│   │   ├── dataset.py
│   │   ├── model.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── tokenizer/
│   │   ├── __init__.py
│   │   └── rwkv_tokenizer.py
│   └── train.py
├── RWKV-v6/
│   └── README.md
├── RWKV-v7/
│   ├── README.md
│   ├── cuda/
│   │   ├── wkv7.cu
│   │   ├── wkv7_op.cpp
│   │   ├── wkv7s.cu
│   │   └── wkv7s_op.cpp
│   ├── misc/
│   │   └── lambada_test.jsonl
│   ├── mmlu_dev_dataset/
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   ├── mmlu_test_dataset/
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   ├── rwkv_mmlu_eval.py
│   ├── rwkv_v7_demo.py
│   ├── rwkv_v7_demo_fast.py
│   ├── rwkv_v7_demo_rnn.py
│   ├── rwkv_v7_numpy.py
│   ├── rwkv_v7a_demo.py
│   ├── rwkv_v7b_demo.py
│   ├── rwkv_v8_rc00_demo.py
│   ├── rwkv_v8_rc00_hybrid_demo.py
│   └── train_temp/
│       ├── README.md
│       ├── cuda/
│       │   ├── rwkv7_clampw.cpp
│       │   ├── rwkv7_clampw.cu
│       │   ├── wkv7_cuda.cu
│       │   ├── wkv7_cuda_fp32.cu
│       │   ├── wkv7_op.cpp
│       │   └── wkv7_op_fp32.cpp
│       ├── demo-training-prepare-v7-pile.sh
│       ├── demo-training-prepare.sh
│       ├── demo-training-run-v7-pile.sh
│       ├── demo-training-run.sh
│       ├── rwkv7_train_simplified.py
│       ├── src/
│       │   ├── __init__.py
│       │   ├── binidx.py
│       │   ├── dataset.py
│       │   ├── model.py
│       │   └── trainer.py
│       └── train.py
├── RWKV-v8/
│   ├── 251014_rosa_1bit_layer.py
│   ├── 251014_rosa_1bit_train.py
│   ├── 251014_rosa_onlyemb_train.py
│   ├── 251016_rosa_1bit_run.py
│   ├── 251018_rosa_4bit_run.py
│   ├── 251024_rosaQKV_run.py
│   ├── 251105_reverse_run.py
│   ├── 260212_rosa1bitLM_L12.py
│   ├── 260222_rosa4bitLM_L12.py
│   ├── README.md
│   └── cuda/
│       ├── wkv7_cuda.cu
│       └── wkv7_op.cpp
└── Research/
    └── rwkv7-g0-7.2b.md
Download .txt
SYMBOL INDEX (1025 symbols across 75 files)

FILE: RWKV-v1/src/model.py
  function RWKV_Init (line 16) | def RWKV_Init(module, config): # fancy initialization of all lin & emb l...
  class RWKV_TimeMix (line 56) | class RWKV_TimeMix(nn.Module):
    method __init__ (line 57) | def __init__(self, config, layer_id):
    method forward (line 96) | def forward(self, x):
  class RWKV_ChannelMix (line 129) | class RWKV_ChannelMix(nn.Module):
    method __init__ (line 130) | def __init__(self, config, layer_id):
    method forward (line 144) | def forward(self, x):
  class RWKV_TinyAttn (line 158) | class RWKV_TinyAttn(nn.Module): # extra tiny attention
    method __init__ (line 159) | def __init__(self, config):
    method forward (line 168) | def forward(self, x, mask):
  class RotaryEmbedding (line 192) | class RotaryEmbedding(torch.nn.Module):
    method __init__ (line 193) | def __init__(self, dim, base=10000):
    method forward (line 201) | def forward(self, x, seq_len=None):
  function rotate_half (line 211) | def rotate_half(x):
  function apply_rotary_pos_emb (line 216) | def apply_rotary_pos_emb(q, k, cos, sin):
  class MHA_rotary (line 220) | class MHA_rotary(nn.Module):
    method __init__ (line 221) | def __init__(self, config, layer_id, time_shift = False):
    method forward (line 243) | def forward(self, x):
  class GeGLU (line 270) | class GeGLU(torch.nn.Module):
    method __init__ (line 271) | def __init__(self, config, layer_id, time_shift = False):
    method forward (line 283) | def forward(self, x):
  class MHA_pro (line 297) | class MHA_pro(nn.Module):
    method __init__ (line 298) | def __init__(self, config, layer_id):
    method forward (line 324) | def forward(self, x):
  class RMSNorm (line 361) | class RMSNorm(nn.Module):
    method __init__ (line 362) | def __init__(self, d):
    method forward (line 367) | def forward(self, x):
  class FixedNorm (line 372) | class FixedNorm(nn.Module):
    method __init__ (line 373) | def __init__(self, d):
    method forward (line 377) | def forward(self, x):
  class GPTConfig (line 384) | class GPTConfig:
    method __init__ (line 385) | def __init__(self, vocab_size, ctx_len, **kwargs):
  class Block (line 391) | class Block(nn.Module):
    method __init__ (line 392) | def __init__(self, config, layer_id):
    method forward (line 417) | def forward(self, x):
  class GPT (line 424) | class GPT(nn.Module):
    method __init__ (line 425) | def __init__(self, config):
    method get_ctx_len (line 452) | def get_ctx_len(self):
    method _init_weights (line 455) | def _init_weights(self, module):
    method configure_optimizers (line 461) | def configure_optimizers(self, train_config):
    method forward (line 494) | def forward(self, idx, targets=None):

FILE: RWKV-v1/src/trainer.py
  class TrainerConfig (line 14) | class TrainerConfig:
    method __init__ (line 29) | def __init__(self, **kwargs):
  class Trainer (line 33) | class Trainer:
    method __init__ (line 35) | def __init__(self, model, train_dataset, test_dataset, config):
    method get_run_name (line 54) | def get_run_name(self):
    method train (line 60) | def train(self):

FILE: RWKV-v1/src/utils.py
  function top_k_logits (line 7) | def top_k_logits(logits, k):
  function top_p_probs (line 13) | def top_p_probs(probs, p):
  function sample_logits (line 27) | def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, ...
  function set_seed (line 46) | def set_seed(seed):

FILE: RWKV-v1/train.py
  class Dataset (line 79) | class Dataset(Dataset):
    method __init__ (line 80) | def __init__(self, data, model_level, ctx_len):
    method __len__ (line 110) | def __len__(self):
    method __getitem__ (line 113) | def __getitem__(self, idx):

FILE: RWKV-v2-RNN/cuda/timex_op.cpp
  function forward (line 6) | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x,...
  function backward (line 9) | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Ten...
  function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 18) | TORCH_LIBRARY(timex, m) {

FILE: RWKV-v2-RNN/src/model.py
  class TimeX (line 26) | class TimeX(torch.autograd.Function):
    method forward (line 28) | def forward(ctx, w, k, B, C, T, eps):
    method backward (line 42) | def backward(ctx, gwk):
  function RWKV_Init (line 63) | def RWKV_Init(module, config):  # fancy initialization of all lin & emb ...
  class RWKV_TimeMix (line 109) | class RWKV_TimeMix(nn.Module):
    method __init__ (line 110) | def __init__(self, config, layer_id):
    method forward (line 162) | def forward(self, x):
  class RWKV_ChannelMix (line 189) | class RWKV_ChannelMix(nn.Module):
    method __init__ (line 190) | def __init__(self, config, layer_id):
    method forward (line 210) | def forward(self, x):
  class GPTConfig (line 225) | class GPTConfig:
    method __init__ (line 226) | def __init__(self, vocab_size, ctx_len, **kwargs):
  class Block (line 233) | class Block(nn.Module):
    method __init__ (line 234) | def __init__(self, config, layer_id):
    method forward (line 249) | def forward(self, x):
  class GPT (line 260) | class GPT(nn.Module):
    method __init__ (line 261) | def __init__(self, config):
    method get_ctx_len (line 288) | def get_ctx_len(self):
    method _init_weights (line 291) | def _init_weights(self, module):
    method configure_optimizers (line 299) | def configure_optimizers(self, train_config):
    method forward (line 327) | def forward(self, idx, targets=None):

FILE: RWKV-v2-RNN/src/model_run.py
  class RWKV_RNN (line 13) | class RWKV_RNN():
    method __init__ (line 14) | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd...
    method clear (line 55) | def clear(self):
    method save (line 61) | def save(self, target):
    method load (line 67) | def load(self, target):
    method LN (line 73) | def LN(self, xx, w):
    method FF (line 76) | def FF(self, xx, w, name):
    method SA (line 88) | def SA(self, xx, w, name):
    method run (line 111) | def run(self, ctx):

FILE: RWKV-v2-RNN/src/trainer.py
  class TrainerConfig (line 30) | class TrainerConfig:
    method __init__ (line 44) | def __init__(self, **kwargs):
  class Trainer (line 49) | class Trainer:
    method __init__ (line 51) | def __init__(self, model, train_dataset, test_dataset, config):
    method get_run_name (line 70) | def get_run_name(self):
    method train (line 78) | def train(self):

FILE: RWKV-v2-RNN/src/utils.py
  class Dataset (line 16) | class Dataset(Dataset):
    method __init__ (line 17) | def __init__(self, data, ctx_len, epoch_length_fixed):
    method __len__ (line 42) | def __len__(self):
    method __getitem__ (line 45) | def __getitem__(self, idx):
  class TOKENIZER (line 57) | class TOKENIZER():
    method __init__ (line 58) | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
    method refine_context (line 69) | def refine_context(self, context):
    method sample_logits (line 80) | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=...
  function to_float (line 114) | def to_float(x):
  function set_seed (line 118) | def set_seed(seed):

FILE: RWKV-v3/cuda/timex_op.cpp
  function forward (line 6) | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x,...
  function backward (line 9) | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Ten...
  function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 18) | TORCH_LIBRARY(timex, m) {

FILE: RWKV-v3/src/model.py
  class TimeX (line 31) | class TimeX(torch.autograd.Function):
    method forward (line 33) | def forward(ctx, w, k, B, C, T, eps):
    method backward (line 47) | def backward(ctx, gwk):
  function RWKV_Init (line 62) | def RWKV_Init(module, config):  # fancy initialization of all lin & emb ...
  class RWKV_TimeMix (line 108) | class RWKV_TimeMix(nn.Module):
    method __init__ (line 109) | def __init__(self, config, layer_id):
    method forward (line 156) | def forward(self, x):
  class RWKV_ChannelMix (line 190) | class RWKV_ChannelMix(nn.Module):
    method __init__ (line 191) | def __init__(self, config, layer_id):
    method forward (line 215) | def forward(self, x):
  class GPTConfig (line 232) | class GPTConfig:
    method __init__ (line 233) | def __init__(self, vocab_size, ctx_len, **kwargs):
  class Block (line 240) | class Block(nn.Module):
    method __init__ (line 241) | def __init__(self, config, layer_id):
    method forward (line 259) | def forward(self, x):
  class GPT (line 270) | class GPT(nn.Module):
    method __init__ (line 271) | def __init__(self, config):
    method get_ctx_len (line 299) | def get_ctx_len(self):
    method _init_weights (line 302) | def _init_weights(self, module):
    method configure_optimizers (line 310) | def configure_optimizers(self, train_config):
    method forward (line 338) | def forward(self, idx, targets=None):

FILE: RWKV-v3/src/model_run.py
  class RWKV_ChannelMix (line 23) | class RWKV_ChannelMix(nn.Module):
    method __init__ (line 24) | def __init__(self, layer_id):
    method forward (line 37) | def forward(self, x):
  class RWKV_TimeMix (line 49) | class RWKV_TimeMix(nn.Module):
    method __init__ (line 50) | def __init__(self, layer_id):
    method forward (line 68) | def forward(self, x):
  class Block (line 97) | class Block(nn.Module):
    method __init__ (line 98) | def __init__(self, layer_id):
    method forward (line 114) | def forward(self, x):
  class RWKV_GPT (line 124) | class RWKV_GPT(nn.Module):
    method __init__ (line 125) | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_l...
    method forward (line 158) | def forward(self, idx):
  class RWKV_RNN (line 181) | class RWKV_RNN():
    method __init__ (line 182) | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd...
    method clear (line 223) | def clear(self):
    method save (line 229) | def save(self, target):
    method load (line 235) | def load(self, target):
    method LN (line 241) | def LN(self, xx, w):
    method FF (line 244) | def FF(self, xx, w, name):
    method SA (line 257) | def SA(self, xx, w, name):
    method run (line 283) | def run(self, ctx):

FILE: RWKV-v3/src/trainer.py
  class TrainerConfig (line 30) | class TrainerConfig:
    method __init__ (line 44) | def __init__(self, **kwargs):
  class Trainer (line 49) | class Trainer:
    method __init__ (line 51) | def __init__(self, model, train_dataset, test_dataset, config):
    method get_run_name (line 70) | def get_run_name(self):
    method train (line 78) | def train(self):

FILE: RWKV-v3/src/utils.py
  class Dataset (line 16) | class Dataset(Dataset):
    method __init__ (line 17) | def __init__(self, data, ctx_len, epoch_length_fixed):
    method __len__ (line 42) | def __len__(self):
    method __getitem__ (line 45) | def __getitem__(self, idx):
  class TOKENIZER (line 57) | class TOKENIZER():
    method __init__ (line 58) | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
    method refine_context (line 69) | def refine_context(self, context):
    method sample_logits (line 80) | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=...
  function to_float (line 114) | def to_float(x):
  function set_seed (line 118) | def set_seed(seed):

FILE: RWKV-v4/cuda/wkv_op.cpp
  function forward (line 6) | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::T...
  function backward (line 9) | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::...
  function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 18) | TORCH_LIBRARY(wkv, m) {

FILE: RWKV-v4/src/binidx.py
  function print_rank_0 (line 10) | def print_rank_0(*message):
  function _warmup_mmap_file (line 18) | def _warmup_mmap_file(path):
  function code (line 35) | def code(dtype):
  function index_file_path (line 41) | def index_file_path(prefix_path):
  function data_file_path (line 44) | def data_file_path(prefix_path):
  class MMapIndexedDataset (line 47) | class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 48) | class Index(object):
      method __init__ (line 51) | def __init__(self, path, skip_warmup=False):
      method __del__ (line 96) | def __del__(self):
      method dtype (line 101) | def dtype(self):
      method sizes (line 105) | def sizes(self):
      method doc_idx (line 109) | def doc_idx(self):
      method __getitem__ (line 113) | def __getitem__(self, i):
      method __len__ (line 116) | def __len__(self):
    method __init__ (line 119) | def __init__(self, path, skip_warmup=False):
    method __getstate__ (line 128) | def __getstate__(self):
    method __setstate__ (line 131) | def __setstate__(self, state):
    method _do_init (line 134) | def _do_init(self, path, skip_warmup):
    method __del__ (line 148) | def __del__(self):
    method __len__ (line 153) | def __len__(self):
    method __getitem__ (line 157) | def __getitem__(self, idx):
    method get (line 179) | def get(self, idx, offset=0, length=None):
    method sizes (line 195) | def sizes(self):
    method doc_idx (line 199) | def doc_idx(self):
    method get_doc_idx (line 202) | def get_doc_idx(self):
    method set_doc_idx (line 205) | def set_doc_idx(self, doc_idx_):
    method supports_prefetch (line 209) | def supports_prefetch(self):
    method exists (line 213) | def exists(path):

FILE: RWKV-v4/src/model.py
  class L2Wrap (line 21) | class L2Wrap(torch.autograd.Function):
    method forward (line 23) | def forward(ctx, loss, y):
    method backward (line 27) | def backward(ctx, grad_output):
  class WKV (line 47) | class WKV(torch.autograd.Function):
    method forward (line 49) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 76) | def backward(ctx, gy):
  function RUN_CUDA (line 100) | def RUN_CUDA(B, T, C, w, u, k, v):
  function RWKV_Init (line 107) | def RWKV_Init(model, args):  # fancy initialization of all lin & emb lay...
  class RWKV_TimeMix (line 164) | class RWKV_TimeMix(torch.jit.ScriptModule):
    method __init__ (line 165) | def __init__(self, config, layer_id):
    method jit_func (line 209) | def jit_func(self, x):
    method forward (line 225) | def forward(self, x):
  class RWKV_ChannelMix (line 235) | class RWKV_ChannelMix(torch.jit.ScriptModule):
    method __init__ (line 236) | def __init__(self, config, layer_id):
    method forward (line 261) | def forward(self, x):
  class GPTConfig (line 278) | class GPTConfig:
    method __init__ (line 279) | def __init__(self, vocab_size, ctx_len, **kwargs):
  class Block (line 286) | class Block(nn.Module):
    method __init__ (line 287) | def __init__(self, config, layer_id):
    method forward (line 305) | def forward(self, x):
  class GPT (line 316) | class GPT(nn.Module):
    method __init__ (line 317) | def __init__(self, config):
    method get_ctx_len (line 349) | def get_ctx_len(self):
    method _init_weights (line 352) | def _init_weights(self, module):
    method configure_optimizers (line 360) | def configure_optimizers(self, train_config):
    method forward (line 382) | def forward(self, idx, targets=None):

FILE: RWKV-v4/src/model_run.py
  class WKV (line 29) | class WKV(torch.autograd.Function):
    method forward (line 31) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 58) | def backward(ctx, gy):
  function RUN_CUDA (line 82) | def RUN_CUDA(B, T, C, w, u, k, v):
  class RWKV_ChannelMix (line 89) | class RWKV_ChannelMix(nn.Module):
    method __init__ (line 90) | def __init__(self, layer_id):
    method forward (line 103) | def forward(self, x):
  class RWKV_TimeMix (line 115) | class RWKV_TimeMix(nn.Module):
    method __init__ (line 116) | def __init__(self, layer_id):
    method forward (line 133) | def forward(self, x):
  class Block (line 150) | class Block(nn.Module):
    method __init__ (line 151) | def __init__(self, layer_id):
    method forward (line 167) | def forward(self, x):
  class RWKV_GPT (line 177) | class RWKV_GPT(nn.Module):
    method __init__ (line 178) | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_l...
    method forward (line 211) | def forward(self, idx):
  class RWKV_RNN (line 240) | class RWKV_RNN(): # this is running in FP32 at this moment
    method __init__ (line 241) | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd...
    method clear (line 281) | def clear(self):
    method save (line 288) | def save(self, target):
    method load (line 295) | def load(self, target):
    method LN (line 302) | def LN(self, xx, w):
    method FF (line 305) | def FF(self, xx, w, name):
    method SA (line 318) | def SA(self, xx, w, name):
    method run (line 356) | def run(self, ctx):

FILE: RWKV-v4/src/trainer.py
  class TrainerConfig (line 28) | class TrainerConfig:
    method __init__ (line 40) | def __init__(self, **kwargs):
  class Trainer (line 46) | class Trainer(LightningLite):
    method get_run_name (line 48) | def get_run_name(self):
    method run (line 56) | def run(self, m_cfg, train_dataset, test_dataset, config):

FILE: RWKV-v4/src/utils.py
  class Dataset (line 18) | class Dataset(Dataset):
    method __init__ (line 19) | def __init__(self, data, ctx_len, epoch_length_fixed):
    method __len__ (line 55) | def __len__(self):
    method __getitem__ (line 58) | def __getitem__(self, idx):
  class TOKENIZER (line 75) | class TOKENIZER():
    method __init__ (line 76) | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
    method refine_context (line 98) | def refine_context(self, context):
    method sample_logits (line 108) | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=...
  function to_float (line 145) | def to_float(x):
  function set_seed (line 149) | def set_seed(seed):

FILE: RWKV-v4neo/chat.py
  function run_rnn (line 145) | def run_rnn(tokens, newline_adj = 0):
  function save_all_stat (line 163) | def save_all_stat(srv, name, last_out):
  function load_all_stat (line 170) | def load_all_stat(srv, name):
  function reply_msg (line 194) | def reply_msg(msg):
  function on_message (line 197) | def on_message(message):

FILE: RWKV-v4neo/cuda/wkv5_op.cpp
  function forward (line 8) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function backward (line 11) | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor ...
  function PYBIND11_MODULE (line 14) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 19) | TORCH_LIBRARY(wkv5, m) {

FILE: RWKV-v4neo/cuda/wkv_op.cpp
  function forward (line 6) | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::T...
  function backward (line 9) | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::...
  function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 18) | TORCH_LIBRARY(wkv, m) {

FILE: RWKV-v4neo/cuda/wkv_op_bf16.cpp
  function forward (line 8) | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::T...
  function backward (line 11) | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::...
  function PYBIND11_MODULE (line 17) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 22) | TORCH_LIBRARY(wkv, m) {

FILE: RWKV-v4neo/img_demoAE.py
  class ToBinary (line 22) | class ToBinary(torch.autograd.Function):
    method forward (line 24) | def forward(ctx, x):
    method backward (line 28) | def backward(ctx, grad_output):
  class R_ENCODER (line 31) | class R_ENCODER(nn.Module):
    method __init__ (line 32) | def __init__(self, args):
    method forward (line 62) | def forward(self, img):
  class R_DECODER (line 84) | class R_DECODER(nn.Module):
    method __init__ (line 85) | def __init__(self, args):
    method forward (line 113) | def forward(self, code):

FILE: RWKV-v4neo/math_demo/run.py
  class TOKENIZER (line 23) | class TOKENIZER():
    method __init__ (line 24) | def __init__(self):
    method encode (line 30) | def encode(self, x):
    method decode (line 33) | def decode(self, x):
  class RWKV_RNN (line 40) | class RWKV_RNN(torch.jit.ScriptModule):
    method __init__ (line 41) | def __init__(self, args):
    method layer_norm (line 68) | def layer_norm(self, x, w):
    method channel_mixing (line 72) | def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, ...
    method time_mixing (line 81) | def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mi...
    method forward (line 109) | def forward(self, token, state):

FILE: RWKV-v4neo/run.py
  function record_time (line 168) | def record_time(name):

FILE: RWKV-v4neo/src/binidx.py
  function print_rank_0 (line 10) | def print_rank_0(*message):
  function _warmup_mmap_file (line 19) | def _warmup_mmap_file(path):
  function code (line 36) | def code(dtype):
  function index_file_path (line 42) | def index_file_path(prefix_path):
  function data_file_path (line 45) | def data_file_path(prefix_path):
  class MMapIndexedDataset (line 48) | class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 49) | class Index(object):
      method writer (line 53) | def writer(cls, path, dtype):
      method __init__ (line 104) | def __init__(self, path, skip_warmup=False):
      method __del__ (line 149) | def __del__(self):
      method dtype (line 154) | def dtype(self):
      method sizes (line 158) | def sizes(self):
      method doc_idx (line 162) | def doc_idx(self):
      method __getitem__ (line 166) | def __getitem__(self, i):
      method __len__ (line 169) | def __len__(self):
    method __init__ (line 172) | def __init__(self, path, skip_warmup=False):
    method __getstate__ (line 181) | def __getstate__(self):
    method __setstate__ (line 184) | def __setstate__(self, state):
    method _do_init (line 187) | def _do_init(self, path, skip_warmup):
    method __del__ (line 201) | def __del__(self):
    method __len__ (line 206) | def __len__(self):
    method __getitem__ (line 210) | def __getitem__(self, idx):
    method get (line 232) | def get(self, idx, offset=0, length=None):
    method sizes (line 248) | def sizes(self):
    method doc_idx (line 252) | def doc_idx(self):
    method get_doc_idx (line 255) | def get_doc_idx(self):
    method set_doc_idx (line 258) | def set_doc_idx(self, doc_idx_):
    method supports_prefetch (line 262) | def supports_prefetch(self):
    method exists (line 266) | def exists(path):

FILE: RWKV-v4neo/src/dataset.py
  class MyDataset (line 14) | class MyDataset(Dataset):
    method __init__ (line 15) | def __init__(self, args):
    method __len__ (line 104) | def __len__(self):
    method __getitem__ (line 107) | def __getitem__(self, idx):

FILE: RWKV-v4neo/src/model.py
  function __nop (line 25) | def __nop(ob):
  class WKV (line 47) | class WKV(torch.autograd.Function):
    method forward (line 49) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 64) | def backward(ctx, gy):
    method forward (line 83) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 109) | def backward(ctx, gy):
  class WKV (line 81) | class WKV(torch.autograd.Function):
    method forward (line 49) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 64) | def backward(ctx, gy):
    method forward (line 83) | def forward(ctx, B, T, C, w, u, k, v):
    method backward (line 109) | def backward(ctx, gy):
  function RUN_CUDA (line 134) | def RUN_CUDA(B, T, C, w, u, k, v):
  class RWKV_TimeMix_RWKV5_Preview (line 139) | class RWKV_TimeMix_RWKV5_Preview(MyModule):
    method __init__ (line 140) | def __init__(self, args, layer_id):
    method jit_func (line 194) | def jit_func(self, x):
    method jit_func_2 (line 211) | def jit_func_2(self, r, k, v, g, w, wk, wb, ws):
    method jit_func (line 232) | def jit_func(self, x):
    method jit_func_2 (line 247) | def jit_func_2(self, r, k, v, w, wk, wb, ws):
    method forward (line 267) | def forward(self, x):
  class WKV_5 (line 320) | class WKV_5(torch.autograd.Function):
    method forward (line 322) | def forward(ctx, B, T, C, H, r, k, v, w, u):
    method backward (line 347) | def backward(ctx, gy):
  function RUN_CUDA_RWKV5 (line 366) | def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
  class RWKV_TimeMix_RWKV5 (line 371) | class RWKV_TimeMix_RWKV5(MyModule):
    method __init__ (line 372) | def __init__(self, args, layer_id):
    method jit_func (line 420) | def jit_func(self, x):
    method jit_func_2 (line 437) | def jit_func_2(self, x, g):
    method forward (line 445) | def forward(self, x):
  class RWKV_TimeMix (line 460) | class RWKV_TimeMix(MyModule):
    method __init__ (line 461) | def __init__(self, args, layer_id):
    method jit_func (line 511) | def jit_func(self, x):
    method forward (line 522) | def forward(self, x):
    method QKV (line 530) | def QKV(self, q, k, v):
    method jit_funcQKV (line 538) | def jit_funcQKV(self, x):
    method forward (line 555) | def forward(self, x):
  class RWKV_ChannelMix (line 564) | class RWKV_ChannelMix(MyModule):
    method __init__ (line 565) | def __init__(self, args, layer_id):
    method forward (line 584) | def forward(self, x):
  class MishGLU (line 593) | class MishGLU(MyModule):
    method __init__ (line 594) | def __init__(self, args, layer_id):
    method forward (line 614) | def forward(self, x):
  class Block (line 627) | class Block(nn.Module):
    method __init__ (line 628) | def __init__(self, args, layer_id):
    method forward (line 668) | def forward(self, x, x_emb=None):
  class L2Wrap (line 700) | class L2Wrap(torch.autograd.Function):
    method forward (line 702) | def forward(ctx, loss, y):
    method backward (line 707) | def backward(ctx, grad_output):
  class RWKV (line 717) | class RWKV(pl.LightningModule):
    method __init__ (line 718) | def __init__(self, args):
    method configure_optimizers (line 747) | def configure_optimizers(self):
    method deepspeed_offload (line 815) | def deepspeed_offload(self) -> bool:
    method forward (line 822) | def forward(self, idx):
    method training_step (line 866) | def training_step(self, batch, batch_idx):
    method training_step_end (line 907) | def training_step_end(self, batch_parts):
    method generate_init_weight (line 913) | def generate_init_weight(self):

FILE: RWKV-v4neo/src/model_img.py
  function __nop (line 18) | def __nop(ob):
  class L2pooling (line 27) | class L2pooling(nn.Module):
    method __init__ (line 28) | def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
    method forward (line 40) | def forward(self, input):
  class DISTS (line 52) | class DISTS(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, load_weights=True):
    method forward_once (line 99) | def forward_once(self, x):
    method forward (line 113) | def forward(self, x, y, require_grad=False, batch_average=False):
    class ToBinary (line 150) | class ToBinary(torch.autograd.Function):
      method forward (line 152) | def forward(ctx, x):#, noise_scale):
      method backward (line 161) | def backward(ctx, grad_output):
  class R_ENCODER (line 166) | class R_ENCODER(MyModule):
    method __init__ (line 167) | def __init__(self, args):
    method forward (line 203) | def forward(self, img):
  class R_DECODER (line 229) | class R_DECODER(MyModule):
    method __init__ (line 230) | def __init__(self, args):
    method forward (line 264) | def forward(self, code):
  function cosine_loss (line 289) | def cosine_loss(x, y):
  class RWKV_IMG (line 294) | class RWKV_IMG(pl.LightningModule):
    method __init__ (line 295) | def __init__(self, args):
    method configure_optimizers (line 330) | def configure_optimizers(self):
    method deepspeed_offload (line 359) | def deepspeed_offload(self) -> bool:
    method forward (line 366) | def forward(self, img):
    method training_step (line 372) | def training_step(self, batch, batch_idx):
    method training_step_end (line 401) | def training_step_end(self, batch_parts):
    method generate_init_weight (line 406) | def generate_init_weight(self):

FILE: RWKV-v4neo/src/model_run.py
  function __nop (line 13) | def __nop(ob):
  class RWKV_RNN (line 35) | class RWKV_RNN(MyModule):
    method __init__ (line 36) | def __init__(self, args):
    method LN (line 116) | def LN(self, x, w):
    method FF (line 122) | def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
    method SA (line 143) | def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time...
    method forward (line 195) | def forward(self, ctx, state, preprocess_only = False):

FILE: RWKV-v4neo/src/trainer.py
  function my_save (line 7) | def my_save(args, trainer, dd, ff):
  class train_callback (line 25) | class train_callback(pl.Callback):
    method __init__ (line 26) | def __init__(self, args):
    method on_train_batch_start (line 30) | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
    method on_train_batch_end (line 116) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_train_epoch_start (line 159) | def on_train_epoch_start(self, trainer, pl_module):
    method on_train_epoch_end (line 171) | def on_train_epoch_end(self, trainer, pl_module):
  function generate_init_weight (line 203) | def generate_init_weight(model, init_weight_name):

FILE: RWKV-v4neo/src/utils.py
  function record_time (line 9) | def record_time(name):
  class TOKENIZER (line 16) | class TOKENIZER():
    method __init__ (line 17) | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
    method refine_context (line 39) | def refine_context(self, context):
    method sample_logits (line 49) | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=...
  function MaybeIsPrime (line 84) | def MaybeIsPrime(number):
  function FermatPrimalityTest (line 91) | def FermatPrimalityTest(number):
  function MillerRabinPrimalityTest (line 102) | def MillerRabinPrimalityTest(number):

FILE: RWKV-v5/compute_magic_prime.py
  function is_prime (line 5) | def is_prime(n):

FILE: RWKV-v5/cuda/wkv5_op.cpp
  function forward (line 8) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function backward (line 11) | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor ...
  function PYBIND11_MODULE (line 14) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 19) | TORCH_LIBRARY(wkv5, m) {

FILE: RWKV-v5/cuda/wkv6_op.cpp
  function forward (line 8) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function backward (line 11) | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor ...
  function TORCH_LIBRARY (line 15) | TORCH_LIBRARY(wkv6, m) {

FILE: RWKV-v5/cuda/wkv6state_op.cpp
  function forward (line 8) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function backward (line 11) | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor ...
  function PYBIND11_MODULE (line 14) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  function TORCH_LIBRARY (line 19) | TORCH_LIBRARY(wkv6state, m) {

FILE: RWKV-v5/cuda/wkv7_op.cpp
  function forward (line 7) | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch...
  function backward (line 14) | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torc...
  function TORCH_LIBRARY (line 21) | TORCH_LIBRARY(wind_backstepping, m) {
  function TORCH_LIBRARY_IMPL (line 26) | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {

FILE: RWKV-v5/make_data.py
  function index_file_path (line 36) | def index_file_path(prefix_path):
  function data_file_path (line 38) | def data_file_path(prefix_path):
  class MMapIndexedDatasetBuilder (line 40) | class MMapIndexedDatasetBuilder(object):
    method __init__ (line 41) | def __init__(self, out_file, dtype=np.uint16):
    method add_item (line 46) | def add_item(self, np_array):
    method end_document (line 50) | def end_document(self):
    method finalize (line 52) | def finalize(self, index_file):
  function add_raw (line 57) | def add_raw(raw):
  function is_prime (line 69) | def is_prime(n):

FILE: RWKV-v5/rwkv_v6_demo.py
  class WKV_6 (line 274) | class WKV_6(torch.autograd.Function):
    method forward (line 276) | def forward(ctx, B, T, C, H, r, k, v, w, u): # forward: r, k, v, w, u ...
    method backward (line 299) | def backward(ctx, gy): # backward: gy => gr, gk, gv, gw, gu
  function RUN_CUDA_RWKV6 (line 317) | def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
  class RWKV_Tmix_x060 (line 324) | class RWKV_Tmix_x060(nn.Module):
    method __init__ (line 325) | def __init__(self, args, layer_id):
    method jit_func (line 377) | def jit_func(self, x):
    method jit_func_2 (line 403) | def jit_func_2(self, x, g):
    method forward (line 411) | def forward(self, x):
  class RWKV_CMix_x060 (line 424) | class RWKV_CMix_x060(nn.Module):
    method __init__ (line 425) | def __init__(self, args, layer_id):
    method forward (line 443) | def forward(self, x):
  class Block (line 457) | class Block(nn.Module):
    method __init__ (line 458) | def __init__(self, args, layer_id):
    method forward (line 472) | def forward(self, x):
  class RWKV (line 486) | class RWKV(nn.Module):
    method __init__ (line 487) | def __init__(self, args):
    method forward (line 506) | def forward(self, idx):
    method init_params (line 518) | def init_params(self):
  class RWKV_TOKENIZER (line 583) | class RWKV_TOKENIZER():
    method __init__ (line 587) | def __init__(self, file_name):
    method encodeBytes (line 618) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 639) | def decodeBytes(self, tokens):
    method encode (line 642) | def encode(self, src: str):
    method decode (line 645) | def decode(self, tokens):
    method printTokens (line 648) | def printTokens(self, tokens):

FILE: RWKV-v5/src/binidx.py
  function print_rank_0 (line 10) | def print_rank_0(*message):
  function _warmup_mmap_file (line 19) | def _warmup_mmap_file(path):
  function code (line 36) | def code(dtype):
  function index_file_path (line 42) | def index_file_path(prefix_path):
  function data_file_path (line 45) | def data_file_path(prefix_path):
  class MMapIndexedDataset (line 48) | class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 49) | class Index(object):
      method writer (line 53) | def writer(cls, path, dtype):
      method __init__ (line 104) | def __init__(self, path, skip_warmup=False):
      method __del__ (line 149) | def __del__(self):
      method dtype (line 154) | def dtype(self):
      method sizes (line 158) | def sizes(self):
      method doc_idx (line 162) | def doc_idx(self):
      method __getitem__ (line 166) | def __getitem__(self, i):
      method __len__ (line 169) | def __len__(self):
    method __init__ (line 172) | def __init__(self, path, skip_warmup=False):
    method __getstate__ (line 181) | def __getstate__(self):
    method __setstate__ (line 184) | def __setstate__(self, state):
    method _do_init (line 187) | def _do_init(self, path, skip_warmup):
    method __del__ (line 201) | def __del__(self):
    method __len__ (line 206) | def __len__(self):
    method __getitem__ (line 210) | def __getitem__(self, idx):
    method get (line 232) | def get(self, idx, offset=0, length=None):
    method sizes (line 248) | def sizes(self):
    method doc_idx (line 252) | def doc_idx(self):
    method get_doc_idx (line 255) | def get_doc_idx(self):
    method set_doc_idx (line 258) | def set_doc_idx(self, doc_idx_):
    method supports_prefetch (line 262) | def supports_prefetch(self):
    method exists (line 266) | def exists(path):

FILE: RWKV-v5/src/dataset.py
  class MyDataset (line 14) | class MyDataset(Dataset):
    method __init__ (line 15) | def __init__(self, args):
    method __len__ (line 99) | def __len__(self):
    method __getitem__ (line 102) | def __getitem__(self, idx):

FILE: RWKV-v5/src/model.py
  function __nop (line 25) | def __nop(ob):
  class WindBackstepping (line 50) | class WindBackstepping(torch.autograd.Function):
    method forward (line 52) | def forward(ctx, w,q,k,v,z,b):
    method backward (line 64) | def backward(ctx, dy):
  function RUN_CUDA_RWKV7g (line 72) | def RUN_CUDA_RWKV7g(q,w,k,v,a,b):
  class WKV_6STATE (line 82) | class WKV_6STATE(torch.autograd.Function):
    method forward (line 84) | def forward(ctx, B, T, C, H, r, k, v, w, u, s):
    method backward (line 109) | def backward(ctx, gy):
  function RUN_CUDA_RWKV6_STATE (line 129) | def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
  class WKV_6 (line 135) | class WKV_6(torch.autograd.Function):
    method forward (line 137) | def forward(ctx, r, k, v, w, u):
    method backward (line 162) | def backward(ctx, gy):
  function RUN_CUDA_RWKV6 (line 180) | def RUN_CUDA_RWKV6(r, k, v, w, u):
  class WKV_5 (line 187) | class WKV_5(torch.autograd.Function):
    method forward (line 189) | def forward(ctx, B, T, C, H, r, k, v, w, u):
    method backward (line 214) | def backward(ctx, gy):
  function RUN_CUDA_RWKV5 (line 233) | def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
  class RWKV_Tmix_x052 (line 241) | class RWKV_Tmix_x052(MyModule):
    method __init__ (line 242) | def __init__(self, args, layer_id):
    method jit_func (line 290) | def jit_func(self, x):
    method jit_func_2 (line 307) | def jit_func_2(self, x, g):
    method forward (line 315) | def forward(self, x):
  class RWKV_Tmix_x060 (line 325) | class RWKV_Tmix_x060(MyModule):
    method __init__ (line 326) | def __init__(self, args, layer_id):
    method forward (line 381) | def forward(self, x):
  class RWKV_Tmix_x060_state (line 414) | class RWKV_Tmix_x060_state(MyModule):
    method __init__ (line 415) | def __init__(self, args, layer_id):
    method jit_func (line 471) | def jit_func(self, x):
    method jit_func_2 (line 498) | def jit_func_2(self, x, g):
    method forward (line 506) | def forward(self, x):
  class RWKV_Tmix_x060a (line 517) | class RWKV_Tmix_x060a(MyModule):
    method __init__ (line 518) | def __init__(self, args, layer_id):
    method forward (line 576) | def forward(self, x):
  class RWKV_Tmix_x060b (line 609) | class RWKV_Tmix_x060b(MyModule):
    method __init__ (line 610) | def __init__(self, args, layer_id):
    method forward (line 658) | def forward(self, x):
  class RWKV_Tmix_x060c (line 686) | class RWKV_Tmix_x060c(MyModule):
    method __init__ (line 687) | def __init__(self, args, layer_id):
    method forward (line 735) | def forward(self, x):
  class RWKV_Tmix_x070 (line 766) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 767) | def __init__(self, args, layer_id):
    method forward (line 859) | def forward(self, x, v_first):
  class RWKV_CMix_x052 (line 895) | class RWKV_CMix_x052(MyModule):
    method __init__ (line 896) | def __init__(self, args, layer_id):
    method forward (line 915) | def forward(self, x):
  class RWKV_CMix_x060 (line 924) | class RWKV_CMix_x060(MyModule):
    method __init__ (line 925) | def __init__(self, args, layer_id):
    method forward (line 944) | def forward(self, x):
  class RWKV_CMix_x070 (line 954) | class RWKV_CMix_x070(MyModule):
    method __init__ (line 955) | def __init__(self, args, layer_id):
    method forward (line 976) | def forward(self, x):
  class MishGLU (line 986) | class MishGLU(MyModule):
    method __init__ (line 987) | def __init__(self, args, layer_id):
    method forward (line 1007) | def forward(self, x):
  class Block (line 1020) | class Block(nn.Module):
    method __init__ (line 1021) | def __init__(self, args, layer_id):
    method forward (line 1077) | def forward(self, x, v_first):
    method forward (line 1087) | def forward(self, x, x_emb=None):
  class L2Wrap (line 1119) | class L2Wrap(torch.autograd.Function):
    method forward (line 1121) | def forward(ctx, loss, y):
    method backward (line 1126) | def backward(ctx, grad_output):
  class RWKV (line 1136) | class RWKV(pl.LightningModule):
    method __init__ (line 1137) | def __init__(self, args):
    method configure_optimizers (line 1169) | def configure_optimizers(self):
    method deepspeed_offload (line 1251) | def deepspeed_offload(self) -> bool:
    method forward (line 1258) | def forward(self, idx):
    method training_step (line 1310) | def training_step(self, batch, batch_idx):
    method training_step_end (line 1351) | def training_step_end(self, batch_parts):
    method generate_init_weight (line 1357) | def generate_init_weight(self):

FILE: RWKV-v5/src/trainer.py
  function my_save (line 7) | def my_save(args, trainer, dd, ff):
  class train_callback (line 32) | class train_callback(pl.Callback):
    method __init__ (line 33) | def __init__(self, args):
    method on_train_batch_start (line 37) | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
    method on_train_batch_end (line 123) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_train_epoch_start (line 166) | def on_train_epoch_start(self, trainer, pl_module):
    method on_train_epoch_end (line 178) | def on_train_epoch_end(self, trainer, pl_module):
  function generate_init_weight (line 210) | def generate_init_weight(model, init_weight_name):

FILE: RWKV-v5/src/utils.py
  function record_time (line 9) | def record_time(name):
  class TOKENIZER (line 16) | class TOKENIZER():
    method __init__ (line 17) | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
    method refine_context (line 39) | def refine_context(self, context):
    method sample_logits (line 49) | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=...
  function MaybeIsPrime (line 84) | def MaybeIsPrime(number):
  function FermatPrimalityTest (line 91) | def FermatPrimalityTest(number):
  function MillerRabinPrimalityTest (line 102) | def MillerRabinPrimalityTest(number):

FILE: RWKV-v5/tokenizer/rwkv_tokenizer.py
  class TRIE (line 5) | class TRIE:
    method __init__ (line 9) | def __init__(self, front=None, ch=None):
    method __repr__ (line 15) | def __repr__(self):
    method add (line 24) | def add(self, key:bytes, idx:int=0, val=None):
    method find_longest (line 35) | def find_longest(self, key:bytes, idx:int=0):
  class TRIE_TOKENIZER (line 49) | class TRIE_TOKENIZER():
    method __init__ (line 50) | def __init__(self, file_name):
    method encodeBytes (line 72) | def encodeBytes(self, src:bytes):
    method decodeBytes (line 83) | def decodeBytes(self, tokens):
    method encode (line 86) | def encode(self, src):
    method decode (line 89) | def decode(self, tokens):
    method printTokens (line 95) | def printTokens(self, tokens):

FILE: RWKV-v7/cuda/wkv7_op.cpp
  function forward (line 9) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function TORCH_LIBRARY (line 13) | TORCH_LIBRARY(wkv7, m) {

FILE: RWKV-v7/cuda/wkv7s_op.cpp
  function forward (line 9) | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &...
  function TORCH_LIBRARY (line 13) | TORCH_LIBRARY(wkv7s, m) {

FILE: RWKV-v7/rwkv_v7_demo.py
  class RWKV_TOKENIZER (line 53) | class RWKV_TOKENIZER():
    method __init__ (line 57) | def __init__(self, file_name):
    method encodeBytes (line 88) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 109) | def decodeBytes(self, tokens):
    method encode (line 112) | def encode(self, src: str):
    method decode (line 115) | def decode(self, tokens):
    method printTokens (line 118) | def printTokens(self, tokens):
  class WKV_7 (line 141) | class WKV_7(torch.autograd.Function):
    method forward (line 143) | def forward(ctx, r, w, k, v, a, b):
  function RWKV7_OP (line 165) | def RWKV7_OP(r, w, k, v, a, b):
  function RWKV7_OP (line 170) | def RWKV7_OP(r, w, k, v, a, b):
  class RWKV_Tmix_x070 (line 209) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 210) | def __init__(self, args, layer_id):
    method forward (line 257) | def forward(self, x, v_first):
  class RWKV_CMix_x070 (line 295) | class RWKV_CMix_x070(MyModule):
    method __init__ (line 296) | def __init__(self, args, layer_id):
    method forward (line 309) | def forward(self, x):
  class Block (line 320) | class Block(MyModule):
    method __init__ (line 321) | def __init__(self, args, layer_id):
    method forward (line 334) | def forward(self, x, v_first):
  class RWKV (line 349) | class RWKV(nn.Module):
    method __init__ (line 350) | def __init__(self, args):
    method forward (line 361) | def forward(self, idx):

FILE: RWKV-v7/rwkv_v7_demo_fast.py
  class WKV_7 (line 63) | class WKV_7(torch.autograd.Function):
    method forward (line 65) | def forward(ctx, state, r, w, k, v, a, b):
  function RWKV7_OP (line 76) | def RWKV7_OP(state, r, w, k, v, a, b):
  class RWKV_x070 (line 81) | class RWKV_x070(MyModule):
    method __init__ (line 82) | def __init__(self, args):
    method forward (line 106) | def forward(self, idx, state, full_output=False):
    method forward_one (line 123) | def forward_one(self, idx:int, state:List[torch.Tensor]):
    method forward_seq (line 154) | def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_ou...
  function RWKV_x070_TMix_one (line 188) | def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_TMix_seq (line 215) | def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_CMix_one (line 251) | def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
  function RWKV_x070_CMix_seq (line 258) | def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
  function sample_logits (line 271) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 299) | class RWKV_TOKENIZER():
    method __init__ (line 303) | def __init__(self, file_name):
    method encodeBytes (line 334) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 355) | def decodeBytes(self, tokens):
    method encode (line 358) | def encode(self, src: str):
    method decode (line 361) | def decode(self, tokens):
    method printTokens (line 364) | def printTokens(self, tokens):

FILE: RWKV-v7/rwkv_v7_demo_rnn.py
  class RWKV_RNN (line 51) | class RWKV_RNN(MyModule):
    method __init__ (line 52) | def __init__(self, args):
    method forward (line 79) | def forward(self, token:int, state:List[torch.Tensor]):
  function time_mixing__ (line 112) | def time_mixing__(layer_id:int, H:int, N:int, x, x_prev, v_first, state,...
  function channel_mixing__ (line 158) | def channel_mixing__(x, x_prev, x_k, kw, vw):
  function sample_logits (line 171) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 199) | class RWKV_TOKENIZER():
    method __init__ (line 203) | def __init__(self, file_name):
    method encodeBytes (line 234) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 255) | def decodeBytes(self, tokens):
    method encode (line 258) | def encode(self, src: str):
    method decode (line 261) | def decode(self, tokens):
    method printTokens (line 264) | def printTokens(self, tokens):

FILE: RWKV-v7/rwkv_v7_numpy.py
  function time_mixing (line 13) | def time_mixing(x, v0, last_x, S, params):
  function channel_mixing (line 46) | def channel_mixing(x, last_x, mix, Wk, Wv):
  function RWKV7 (line 52) | def RWKV7(params, token, state):

FILE: RWKV-v7/rwkv_v7a_demo.py
  class WKV_7 (line 65) | class WKV_7(torch.autograd.Function):
    method forward (line 67) | def forward(ctx, state, r, w, k, v, a, b):
  function RWKV7_OP (line 78) | def RWKV7_OP(state, r, w, k, v, a, b):
  class RWKV_x070 (line 83) | class RWKV_x070(MyModule):
    method __init__ (line 84) | def __init__(self, args):
    method forward (line 112) | def forward(self, idx, state, full_output=False):
    method forward_one (line 129) | def forward_one(self, idx:int, state:List[torch.Tensor]):
    method forward_seq (line 160) | def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_ou...
  function RWKV_x070_TMix_one (line 194) | def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_TMix_seq (line 221) | def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_CMix_one (line 257) | def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_, semb_, s1_, s2_, s0_):
  function RWKV_x070_CMix_seq (line 266) | def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_, semb_, s1_, s2_, s0_):
  function sample_logits (line 282) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 310) | class RWKV_TOKENIZER():
    method __init__ (line 314) | def __init__(self, file_name):
    method encodeBytes (line 345) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 366) | def decodeBytes(self, tokens):
    method encode (line 369) | def encode(self, src: str):
    method decode (line 372) | def decode(self, tokens):
    method printTokens (line 375) | def printTokens(self, tokens):

FILE: RWKV-v7/rwkv_v7b_demo.py
  class WKV_7 (line 65) | class WKV_7(torch.autograd.Function):
    method forward (line 67) | def forward(ctx, state, r, w, k, v, a, b):
  function RWKV7_OP (line 78) | def RWKV7_OP(state, r, w, k, v, a, b):
  class RWKV_x070 (line 83) | class RWKV_x070(MyModule):
    method __init__ (line 84) | def __init__(self, args):
    method forward (line 114) | def forward(self, idx, state, full_output=False):
    method forward_seq (line 138) | def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_ou...
  function RWKV_x070_TMix_seq (line 196) | def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_CMix_seq (line 232) | def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_, semb_, s1_, s2_, s0_):
  function sample_logits (line 248) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 276) | class RWKV_TOKENIZER():
    method __init__ (line 280) | def __init__(self, file_name):
    method encodeBytes (line 311) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 332) | def decodeBytes(self, tokens):
    method encode (line 335) | def encode(self, src: str):
    method decode (line 338) | def decode(self, tokens):
    method printTokens (line 341) | def printTokens(self, tokens):

FILE: RWKV-v7/rwkv_v8_rc00_demo.py
  class WKV_7 (line 67) | class WKV_7(torch.autograd.Function):
    method forward (line 69) | def forward(ctx, state, r, w, k, v, a, b):
  function RWKV7_OP (line 80) | def RWKV7_OP(state, r, w, k, v, a, b):
  class RWKV_x070 (line 85) | class RWKV_x070(MyModule):
    method __init__ (line 86) | def __init__(self, args):
    method forward (line 110) | def forward(self, idx, state, full_output=False):
    method forward_one (line 127) | def forward_one(self, idx:int, state:List[torch.Tensor]):
    method forward_seq (line 158) | def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_ou...
  function RWKV_x070_TMix_one (line 192) | def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x070_TMix_seq (line 219) | def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x080_CMix_one (line 255) | def RWKV_x080_CMix_one(x, x_prev, x_k, K_, V_, E_):
  function RWKV_x080_CMix_seq (line 262) | def RWKV_x080_CMix_seq(x, x_prev, x_k, K_, V_, E_):
  function sample_logits (line 275) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 304) | class RWKV_TOKENIZER():
    method __init__ (line 308) | def __init__(self, file_name):
    method encodeBytes (line 339) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 360) | def decodeBytes(self, tokens):
    method encode (line 363) | def encode(self, src: str):
    method decode (line 366) | def decode(self, tokens):
    method printTokens (line 369) | def printTokens(self, tokens):
    method __init__ (line 383) | def __init__(self):
    method encode (line 385) | def encode(self, x):
    method decode (line 387) | def decode(self, x):
  class RWKV_TOKENIZER (line 382) | class RWKV_TOKENIZER():
    method __init__ (line 308) | def __init__(self, file_name):
    method encodeBytes (line 339) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 360) | def decodeBytes(self, tokens):
    method encode (line 363) | def encode(self, src: str):
    method decode (line 366) | def decode(self, tokens):
    method printTokens (line 369) | def printTokens(self, tokens):
    method __init__ (line 383) | def __init__(self):
    method encode (line 385) | def encode(self, x):
    method decode (line 387) | def decode(self, x):

FILE: RWKV-v7/rwkv_v8_rc00_hybrid_demo.py
  class WKV_7 (line 67) | class WKV_7(torch.autograd.Function):
    method forward (line 69) | def forward(ctx, state, r, w, k, v, a, b):
  function RWKV7_OP (line 80) | def RWKV7_OP(state, r, w, k, v, a, b):
  class RWKV_x070 (line 85) | class RWKV_x070(MyModule):
    method __init__ (line 86) | def __init__(self, args):
    method forward (line 109) | def forward(self, idx, state, full_output=False):
    method forward_seq (line 134) | def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_ou...
  function RWKV_x070_TMix_seq (line 191) | def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, ...
  function RWKV_x080_CMix_seq (line 218) | def RWKV_x080_CMix_seq(x, x_prev, x_k, K_, V_, E_):
  function sample_logits (line 231) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  class RWKV_TOKENIZER (line 260) | class RWKV_TOKENIZER():
    method __init__ (line 264) | def __init__(self, file_name):
    method encodeBytes (line 295) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 316) | def decodeBytes(self, tokens):
    method encode (line 319) | def encode(self, src: str):
    method decode (line 322) | def decode(self, tokens):
    method printTokens (line 325) | def printTokens(self, tokens):
    method __init__ (line 339) | def __init__(self):
    method encode (line 341) | def encode(self, x):
    method decode (line 343) | def decode(self, x):
  class RWKV_TOKENIZER (line 338) | class RWKV_TOKENIZER():
    method __init__ (line 264) | def __init__(self, file_name):
    method encodeBytes (line 295) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 316) | def decodeBytes(self, tokens):
    method encode (line 319) | def encode(self, src: str):
    method decode (line 322) | def decode(self, tokens):
    method printTokens (line 325) | def printTokens(self, tokens):
    method __init__ (line 339) | def __init__(self):
    method encode (line 341) | def encode(self, x):
    method decode (line 343) | def decode(self, x):

FILE: RWKV-v7/train_temp/cuda/rwkv7_clampw.cpp
  function forward (line 12) | void forward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch...
  function backward (line 19) | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torc...
  function TORCH_LIBRARY (line 26) | TORCH_LIBRARY(rwkv7_clampw, m) {

FILE: RWKV-v7/train_temp/cuda/wkv7_op.cpp
  function forward (line 7) | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch...
  function backward (line 14) | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torc...
  function TORCH_LIBRARY (line 21) | TORCH_LIBRARY(wind_backstepping, m) {
  function TORCH_LIBRARY_IMPL (line 26) | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {

FILE: RWKV-v7/train_temp/cuda/wkv7_op_fp32.cpp
  function forward (line 7) | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch...
  function backward (line 14) | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torc...
  function TORCH_LIBRARY (line 21) | TORCH_LIBRARY(wind_backstepping, m) {
  function TORCH_LIBRARY_IMPL (line 26) | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {

FILE: RWKV-v7/train_temp/rwkv7_train_simplified.py
  function set_seed_all (line 14) | def set_seed_all(seed):
  class WindBackstepping (line 39) | class WindBackstepping(torch.autograd.Function):
    method forward (line 41) | def forward(ctx, w,q,k,v,z,b):
    method backward (line 53) | def backward(ctx, dy):
  function RUN_CUDA_RWKV7g (line 60) | def RUN_CUDA_RWKV7g(q,w,k,v,a,b):
  class RWKV_Tmix_x070 (line 66) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 67) | def __init__(self, args, layer_id):
    method forward (line 147) | def forward(self, x, v_first):
  function _digits (line 186) | def _digits(n): return [TOK[c] for c in str(n)]
  function batch (line 188) | def batch(B,T, device=None):
  class FFN (line 200) | class FFN(nn.Module):
    method __init__ (line 201) | def __init__(self, C):
    method forward (line 210) | def forward(self, x):
  class MODEL (line 216) | class MODEL(nn.Module):
    method __init__ (line 217) | def __init__(s):
    method forward (line 243) | def forward(s,x):

FILE: RWKV-v7/train_temp/src/binidx.py
  function print_rank_0 (line 8) | def print_rank_0(*message):
  function _warmup_mmap_file (line 17) | def _warmup_mmap_file(path):
  function code (line 34) | def code(dtype):
  function index_file_path (line 40) | def index_file_path(prefix_path):
  function data_file_path (line 43) | def data_file_path(prefix_path):
  class MMapIndexedDataset (line 46) | class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 47) | class Index(object):
      method writer (line 51) | def writer(cls, path, dtype):
      method __init__ (line 102) | def __init__(self, path, skip_warmup=True):
      method __del__ (line 147) | def __del__(self):
      method dtype (line 152) | def dtype(self):
      method sizes (line 156) | def sizes(self):
      method doc_idx (line 160) | def doc_idx(self):
      method __getitem__ (line 164) | def __getitem__(self, i):
      method __len__ (line 167) | def __len__(self):
    method __init__ (line 170) | def __init__(self, path, skip_warmup=True):
    method __getstate__ (line 179) | def __getstate__(self):
    method __setstate__ (line 182) | def __setstate__(self, state):
    method _do_init (line 185) | def _do_init(self, path, skip_warmup=True):
    method __del__ (line 199) | def __del__(self):
    method __len__ (line 204) | def __len__(self):
    method __getitem__ (line 208) | def __getitem__(self, idx):
    method get (line 230) | def get(self, idx, offset=0, length=None):
    method sizes (line 246) | def sizes(self):
    method doc_idx (line 250) | def doc_idx(self):
    method get_doc_idx (line 253) | def get_doc_idx(self):
    method set_doc_idx (line 256) | def set_doc_idx(self, doc_idx_):
    method supports_prefetch (line 260) | def supports_prefetch(self):
    method exists (line 264) | def exists(path):

FILE: RWKV-v7/train_temp/src/dataset.py
  function is_prime (line 12) | def is_prime(n):
  class MyDataset (line 26) | class MyDataset(Dataset):
    method __init__ (line 27) | def __init__(self, args):
    method __len__ (line 46) | def __len__(self):
    method __getitem__ (line 49) | def __getitem__(self, idx):

FILE: RWKV-v7/train_temp/src/model.py
  function __nop (line 21) | def __nop(ob):
  class RWKV7_CLAMPW_CUDA_OP (line 48) | class RWKV7_CLAMPW_CUDA_OP(torch.autograd.Function):
    method forward (line 50) | def forward(ctx,r,w,k,v,a,b):
    method backward (line 62) | def backward(ctx,dy):
  function RWKV7_CLAMPW_CUDA (line 69) | def RWKV7_CLAMPW_CUDA(r,w,k,v,a,b):
  class RWKV_Tmix_x070 (line 76) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 77) | def __init__(self, args, layer_id):
    method forward (line 164) | def forward(self, x, v_first):
  class RWKV_CMix_x070 (line 200) | class RWKV_CMix_x070(MyModule):
    method __init__ (line 201) | def __init__(self, args, layer_id):
    method forward (line 221) | def forward(self, x):
  class Block (line 234) | class Block(nn.Module):
    method __init__ (line 235) | def __init__(self, args, layer_id):
    method forward (line 249) | def forward(self, x, v_first):
  class L2Wrap (line 260) | class L2Wrap(torch.autograd.Function):
    method forward (line 262) | def forward(ctx, loss, y):
    method backward (line 267) | def backward(ctx, grad_output):
  class RWKV (line 277) | class RWKV(pl.LightningModule):
    method __init__ (line 278) | def __init__(self, args):
    method configure_optimizers (line 296) | def configure_optimizers(self):
    method deepspeed_offload (line 337) | def deepspeed_offload(self) -> bool:
    method forward (line 344) | def forward(self, idx):
    method training_step (line 362) | def training_step(self, batch, batch_idx):
    method training_step_end (line 368) | def training_step_end(self, batch_parts):
    method generate_init_weight (line 373) | def generate_init_weight(self):

FILE: RWKV-v7/train_temp/src/trainer.py
  function my_save (line 7) | def my_save(args, trainer, dd, ff):
  class train_callback (line 13) | class train_callback(pl.Callback):
    method __init__ (line 14) | def __init__(self, args):
    method on_train_batch_start (line 18) | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
    method on_train_batch_end (line 81) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_train_epoch_start (line 121) | def on_train_epoch_start(self, trainer, pl_module):
    method on_train_epoch_end (line 130) | def on_train_epoch_end(self, trainer, pl_module):
  function generate_init_weight (line 159) | def generate_init_weight(model, init_weight_name):

FILE: RWKV-v8/251014_rosa_1bit_layer.py
  function rosa (line 6) | def rosa(x):
  class ROSA_1bit (line 27) | class ROSA_1bit(torch.autograd.Function): # !!! extremely slow !!!
    method forward (line 29) | def forward(ctx, x, emb0, emb1, tau: float):
    method backward (line 42) | def backward(ctx, gy):
  class ROSA_1bit_LAYER (line 86) | class ROSA_1bit_LAYER(nn.Module): # !!! extremely slow !!!
    method __init__ (line 87) | def __init__(self, C: int, tau: float = 1e-3):
    method forward (line 92) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: RWKV-v8/251014_rosa_1bit_train.py
  function rosa (line 9) | def rosa(x):
  function rosa_torch (line 30) | def rosa_torch(z: torch.Tensor) -> torch.Tensor:
  class Emb_ROSA (line 35) | class Emb_ROSA(nn.Module):
    method __init__ (line 36) | def __init__(s,V,C):
    method forward (line 39) | def forward(s,idx):
  class ROSA_1bit (line 46) | class ROSA_1bit(torch.autograd.Function): # !!! extremely slow !!!
    method forward (line 48) | def forward(ctx, x, emb0, emb1, tau: float):
    method backward (line 61) | def backward(ctx, gy):
  class ROSA_1bit_LAYER (line 106) | class ROSA_1bit_LAYER(nn.Module): # !!! extremely slow !!!
    method __init__ (line 107) | def __init__(self, C: int, tau: float = 1e-3):
    method forward (line 112) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function batch (line 122) | def batch(B,T,nn=None):
  class MODEL (line 136) | class MODEL(nn.Module):
    method __init__ (line 137) | def __init__(s):
    method forward (line 143) | def forward(s,x):

FILE: RWKV-v8/251014_rosa_onlyemb_train.py
  function rosa (line 9) | def rosa(x):
  function rosa_torch (line 30) | def rosa_torch(z: torch.Tensor) -> torch.Tensor:
  class Emb_ROSA (line 35) | class Emb_ROSA(nn.Module):
    method __init__ (line 36) | def __init__(s,V,C):
    method forward (line 39) | def forward(s,idx):
  function batch (line 51) | def batch(B,T,nn=None):
  class MODEL (line 65) | class MODEL(nn.Module):
    method __init__ (line 66) | def __init__(s):
    method forward (line 71) | def forward(s,x):

FILE: RWKV-v8/251016_rosa_1bit_run.py
  function rosa (line 9) | def rosa(x):
  function rosa_torch (line 30) | def rosa_torch(z: torch.Tensor) -> torch.Tensor:
  class Emb_ROSA (line 35) | class Emb_ROSA(nn.Module):
    method __init__ (line 36) | def __init__(s,V,C):
    method forward (line 39) | def forward(s,idx):
  class ROSA_1bit (line 46) | class ROSA_1bit(torch.autograd.Function):
    method forward (line 48) | def forward(ctx, x, emb0, emb1, tau: float):
  class ROSA_1bit_LAYER (line 67) | class ROSA_1bit_LAYER(nn.Module):
    method __init__ (line 68) | def __init__(self, C: int, tau: float = 1e-3):
    method forward (line 73) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function batch (line 83) | def batch(B,T,nn=None):
  class MODEL (line 97) | class MODEL(nn.Module):
    method __init__ (line 98) | def __init__(s):
    method forward (line 106) | def forward(s,x):

FILE: RWKV-v8/251018_rosa_4bit_run.py
  function rosa (line 9) | def rosa(x):
  function rosa_batch_python_orig (line 30) | def rosa_batch_python_orig(z: torch.Tensor) -> torch.Tensor:
  function rosa_batch_python (line 35) | def rosa_batch_python(z: torch.Tensor) -> torch.Tensor:
  class rosa_emb_layer (line 40) | class rosa_emb_layer(nn.Module):
    method __init__ (line 41) | def __init__(s,V,C):
    method forward (line 44) | def forward(s,idx):
  class rosa_4bit_layer (line 49) | class rosa_4bit_layer(nn.Module):
    method __init__ (line 50) | def __init__(self, C: int, eps: float = 1e-5):
    method forward (line 55) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function batch (line 77) | def batch(B,T,nn=None):
  class MODEL (line 91) | class MODEL(nn.Module):
    method __init__ (line 92) | def __init__(s):
    method forward (line 104) | def forward(s,x):

FILE: RWKV-v8/251024_rosaQKV_run.py
  function set_seed_all (line 9) | def set_seed_all(seed):
  function samx_qkv_slow (line 29) | def samx_qkv_slow(qqq, kkk, vvv): # slow, only for reference
  function samx_qkv_batch_ref (line 50) | def samx_qkv_batch_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor...
  class samx_qkv_1bit_layer_op (line 62) | class samx_qkv_1bit_layer_op(torch.autograd.Function):
    method forward (line 64) | def forward(ctx, q, k, v, e):
  class samx_qkv_1bit_layer (line 72) | class samx_qkv_1bit_layer(nn.Module):
    method __init__ (line 73) | def __init__(self, C: int):
    method forward (line 76) | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -...
  class ROSA_QKV_B_1bit (line 79) | class ROSA_QKV_B_1bit(nn.Module):
    method __init__ (line 80) | def __init__(s,C):
    method forward (line 91) | def forward(s,x):
  class WindBackstepping (line 106) | class WindBackstepping(torch.autograd.Function):
    method forward (line 108) | def forward(ctx, w,q,k,v,z,b):
    method backward (line 120) | def backward(ctx, dy):
  function RUN_CUDA_RWKV7g (line 127) | def RUN_CUDA_RWKV7g(q,w,k,v,a,b):
  class RWKV_Tmix_x070 (line 132) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 133) | def __init__(self, args, layer_id):
    method forward (line 210) | def forward(self, x, v_first):
  class FFN (line 246) | class FFN(nn.Module):
    method __init__ (line 247) | def __init__(self, C):
    method forward (line 256) | def forward(self, x):
  class MODEL (line 263) | class MODEL(nn.Module):
    method __init__ (line 264) | def __init__(s):
    method forward (line 306) | def forward(s,x):
    method __init__ (line 331) | def __init__(s):
    method forward (line 359) | def forward(s,x):
  class MODEL (line 330) | class MODEL(nn.Module):
    method __init__ (line 264) | def __init__(s):
    method forward (line 306) | def forward(s,x):
    method __init__ (line 331) | def __init__(s):
    method forward (line 359) | def forward(s,x):
  function get_randint (line 378) | def get_randint(digits):

FILE: RWKV-v8/251105_reverse_run.py
  function set_seed_all (line 9) | def set_seed_all(seed):
  function samx_qkv_slow (line 41) | def samx_qkv_slow(qqq, kkk, vvv): # slow, only for reference
  function samx_qkv_batch_ref (line 62) | def samx_qkv_batch_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor...
  class samx_qkv_1bit_layer_op (line 74) | class samx_qkv_1bit_layer_op(torch.autograd.Function):
    method forward (line 76) | def forward(ctx, q, k, v, e):
  class samx_qkv_1bit_layer (line 84) | class samx_qkv_1bit_layer(nn.Module):
    method __init__ (line 85) | def __init__(self, C: int):
    method forward (line 88) | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -...
  class ROSA_QKV_B_1bit (line 91) | class ROSA_QKV_B_1bit(nn.Module):
    method __init__ (line 92) | def __init__(s,C):
    method forward (line 103) | def forward(s,x):
  class WindBackstepping (line 118) | class WindBackstepping(torch.autograd.Function):
    method forward (line 120) | def forward(ctx, w,q,k,v,z,b):
    method backward (line 132) | def backward(ctx, dy):
  function RUN_CUDA_RWKV7g (line 139) | def RUN_CUDA_RWKV7g(q,w,k,v,a,b):
  class RWKV_Tmix_x070 (line 144) | class RWKV_Tmix_x070(MyModule):
    method __init__ (line 145) | def __init__(self, args, layer_id):
    method forward (line 222) | def forward(self, x, v_first):
  class FFN (line 258) | class FFN(nn.Module):
    method __init__ (line 259) | def __init__(self, C):
    method forward (line 268) | def forward(self, x):
  class MODEL (line 275) | class MODEL(nn.Module):
    method __init__ (line 276) | def __init__(s):
    method forward (line 318) | def forward(s,x):
    method __init__ (line 343) | def __init__(s):
    method forward (line 371) | def forward(s,x):
  class MODEL (line 342) | class MODEL(nn.Module):
    method __init__ (line 276) | def __init__(s):
    method forward (line 318) | def forward(s,x):
    method __init__ (line 343) | def __init__(s):
    method forward (line 371) | def forward(s,x):
  function get_randint (line 390) | def get_randint(digits):

FILE: RWKV-v8/260212_rosa1bitLM_L12.py
  class RWKV_TOKENIZER (line 43) | class RWKV_TOKENIZER():
    method __init__ (line 47) | def __init__(self, file_name):
    method encodeBytes (line 78) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 99) | def decodeBytes(self, tokens):
    method encode (line 102) | def encode(self, src: str):
    method decode (line 105) | def decode(self, tokens):
    method printTokens (line 108) | def printTokens(self, tokens):
  function sample_logits (line 122) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  function rosa_qkv_ref (line 150) | def rosa_qkv_ref(qqq, kkk, vvv):
  function rosa_qkv_batch_ref (line 171) | def rosa_qkv_batch_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor...
  class rosa_qkv_1bit_layer_op (line 183) | class rosa_qkv_1bit_layer_op(torch.autograd.Function):
    method forward (line 185) | def forward(ctx, q, k, v, e):
  class rosa_qkv_1bit_layer (line 193) | class rosa_qkv_1bit_layer(nn.Module):
    method __init__ (line 194) | def __init__(self, C: int):
    method forward (line 197) | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -...
  class RWKV_ROSA_1bit (line 200) | class RWKV_ROSA_1bit(nn.Module):
    method __init__ (line 201) | def __init__(s,C):
    method forward (line 212) | def forward(s,x):
  class RWKV_CMix_x070 (line 224) | class RWKV_CMix_x070(MyModule):
    method __init__ (line 225) | def __init__(self, args, layer_id):
    method forward (line 238) | def forward(self, x):
  class Block (line 249) | class Block(MyModule):
    method __init__ (line 250) | def __init__(self, args, layer_id):
    method forward (line 263) | def forward(self, x, v_first):
  class RWKV (line 277) | class RWKV(nn.Module):
    method __init__ (line 278) | def __init__(self, args):
    method forward (line 289) | def forward(self, idx):

FILE: RWKV-v8/260222_rosa4bitLM_L12.py
  function __nop (line 39) | def __nop(ob):
  class RWKV_TOKENIZER (line 49) | class RWKV_TOKENIZER():
    method __init__ (line 53) | def __init__(self, file_name):
    method encodeBytes (line 84) | def encodeBytes(self, src: bytes) -> list[int]:
    method decodeBytes (line 105) | def decodeBytes(self, tokens):
    method encode (line 108) | def encode(self, src: str):
    method decode (line 111) | def decode(self, tokens):
    method printTokens (line 114) | def printTokens(self, tokens):
  function sample_logits (line 128) | def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:...
  function rosa_slow_ref (line 156) | def rosa_slow_ref(q, k, v):
  class rosa_slow_4bit_layer (line 175) | class rosa_slow_4bit_layer(nn.Module): # !!! matched 1 => e, matched 0 =...
    method __init__ (line 176) | def __init__(self, C):
    method forward (line 179) | def forward(self, q, k, v):
  class RWKV_ROSA_4bit (line 215) | class RWKV_ROSA_4bit(nn.Module):
    method __init__ (line 216) | def __init__(s,C):
    method forward (line 227) | def forward(s,x):
  class RWKV_CMix_x070 (line 239) | class RWKV_CMix_x070(MyModule):
    method __init__ (line 240) | def __init__(self, args, layer_id):
    method forward (line 253) | def forward(self, x):
  class Block (line 264) | class Block(MyModule):
    method __init__ (line 265) | def __init__(self, args, layer_id):
    method forward (line 278) | def forward(self, x, v_first):
  class RWKV (line 292) | class RWKV(nn.Module):
    method __init__ (line 293) | def __init__(self, args):
    method forward (line 304) | def forward(self, idx):

FILE: RWKV-v8/cuda/wkv7_op.cpp
  function forward (line 7) | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch...
  function backward (line 14) | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torc...
  function TORCH_LIBRARY (line 21) | TORCH_LIBRARY(wind_backstepping, m) {
  function TORCH_LIBRARY_IMPL (line 26) | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
Condensed preview — 139 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (8,965K chars).
[
  {
    "path": ".github/FUNDING.yml",
    "chars": 15,
    "preview": "ko_fi: rwkv_lm\n"
  },
  {
    "path": ".gitignore",
    "chars": 1881,
    "preview": "*.txt\n*.csv\n*.pth\n*.xlsb\n*.xlsx\n*.xls\nwandb/\ndata/\nvocab.json\n*log/\ntest/\ntools/\n\n# Byte-compiled / optimized / DLL file"
  },
  {
    "path": "CITATION.cff",
    "chars": 310,
    "preview": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n- family-names: \"PENG\"\n  given"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 59107,
    "preview": "# RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced as \"RwaKuv\" (rʌkuv in IPA), from 4 major p"
  },
  {
    "path": "RWKV-8.md",
    "chars": 3250,
    "preview": "# Improving RNNs (RWKV-8 and beyond)\n\nHere I will show a framework to improve current RNNs.\n\n## 1. Larger State\n\nThis in"
  },
  {
    "path": "RWKV-v1/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "RWKV-v1/src/model.py",
    "chars": 22021,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v1/src/trainer.py",
    "chars": 6197,
    "preview": "import math, sys, datetime\nimport logging\nimport numpy as np\nfrom tqdm.auto import tqdm\nimport torch\nimport torch.optim "
  },
  {
    "path": "RWKV-v1/src/utils.py",
    "chars": 1513,
    "preview": "import random\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\ndef top_k_logi"
  },
  {
    "path": "RWKV-v1/train.py",
    "chars": 6908,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v2-RNN/cuda/timex_cuda.cu",
    "chars": 5633,
    "preview": "#include <stdio.h>\n\n// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compil"
  },
  {
    "path": "RWKV-v2-RNN/cuda/timex_op.cpp",
    "chars": 1034,
    "preview": "#include <torch/extension.h>\n\nvoid cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T"
  },
  {
    "path": "RWKV-v2-RNN/run.py",
    "chars": 4784,
    "preview": "# -*- coding:utf-8 -*-\n#################################################################################################"
  },
  {
    "path": "RWKV-v2-RNN/src/model.py",
    "chars": 13280,
    "preview": "########################################################################################################\n# The RWKV v2-R"
  },
  {
    "path": "RWKV-v2-RNN/src/model_run.py",
    "chars": 4598,
    "preview": "import types\nimport copy\nimport torch\nfrom torch.nn import functional as F\n\nRWKV_K_CLAMP = 60\nRWKV_K_EPS = 1e-16\nRWKV_HE"
  },
  {
    "path": "RWKV-v2-RNN/src/trainer.py",
    "chars": 7098,
    "preview": "########################################################################################################\n# The RWKV v2-R"
  },
  {
    "path": "RWKV-v2-RNN/src/utils.py",
    "chars": 4082,
    "preview": "########################################################################################################\n# The RWKV v2-R"
  },
  {
    "path": "RWKV-v2-RNN/train.py",
    "chars": 4176,
    "preview": "########################################################################################################\n# The RWKV v2-R"
  },
  {
    "path": "RWKV-v3/cuda/timex_cuda.cu",
    "chars": 5633,
    "preview": "#include <stdio.h>\n\n// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compil"
  },
  {
    "path": "RWKV-v3/cuda/timex_op.cpp",
    "chars": 1034,
    "preview": "#include <torch/extension.h>\n\nvoid cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T"
  },
  {
    "path": "RWKV-v3/run.py",
    "chars": 3542,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/src/model.py",
    "chars": 14277,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/src/model_run.py",
    "chars": 11329,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/src/trainer.py",
    "chars": 7082,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/src/utils.py",
    "chars": 4075,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/train.py",
    "chars": 5576,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v3/verify.py",
    "chars": 2359,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/20B_tokenizer.json",
    "chars": 2467981,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"special\":"
  },
  {
    "path": "RWKV-v4/cuda/wkv_cuda.cu",
    "chars": 4348,
    "preview": "#include <stdio.h>\n#include <assert.h>\n\n#define MIN_VALUE (-1e38)\n\ntemplate <typename F>\n__global__ void kernel_forward("
  },
  {
    "path": "RWKV-v4/cuda/wkv_op.cpp",
    "chars": 1203,
    "preview": "#include <torch/extension.h>\n\nvoid cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);\n"
  },
  {
    "path": "RWKV-v4/run.py",
    "chars": 5528,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/src/binidx.py",
    "chars": 6647,
    "preview": "from lib2to3.pgen2 import token\nimport os\nimport torch\nimport numpy as np\nimport shutil\nimport struct\nfrom functools imp"
  },
  {
    "path": "RWKV-v4/src/model.py",
    "chars": 15730,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/src/model_run.py",
    "chars": 14599,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/src/trainer.py",
    "chars": 8055,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/src/utils.py",
    "chars": 5626,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/train.py",
    "chars": 12277,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4/verify.py",
    "chars": 3441,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/20B_tokenizer.json",
    "chars": 2467981,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"special\":"
  },
  {
    "path": "RWKV-v4neo/chat.py",
    "chars": 12062,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/cuda/wkv5_cuda.cu",
    "chars": 5658,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\ntemplate <typename F>\n__global"
  },
  {
    "path": "RWKV-v4neo/cuda/wkv5_op.cpp",
    "chars": 1469,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\nvoid cuda_forward(int B, int T, int C, i"
  },
  {
    "path": "RWKV-v4neo/cuda/wkv_cuda.cu",
    "chars": 4604,
    "preview": "#include <stdio.h>\n#include <assert.h>\n\n#define MIN_VALUE (-1e38)\n\ntemplate <typename F>\n__global__ void kernel_forward("
  },
  {
    "path": "RWKV-v4neo/cuda/wkv_cuda_bf16.cu",
    "chars": 4901,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\n#define MIN_VALUE (-1e38)\ntypedef at::BFloat16 bf16;\n\n__gl"
  },
  {
    "path": "RWKV-v4neo/cuda/wkv_op.cpp",
    "chars": 1252,
    "preview": "#include <torch/extension.h>\n\nvoid cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);\n"
  },
  {
    "path": "RWKV-v4neo/cuda/wkv_op_bf16.cpp",
    "chars": 1288,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\nvoid cuda_forward(int B, int T, int C, f"
  },
  {
    "path": "RWKV-v4neo/img_demoAE.py",
    "chars": 6185,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/math_demo/run.py",
    "chars": 6006,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/run.py",
    "chars": 7845,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "RWKV-v4neo/src/binidx.py",
    "chars": 8648,
    "preview": "from lib2to3.pgen2 import token\nimport os\nimport torch\nimport numpy as np\nimport shutil\nimport struct\nfrom functools imp"
  },
  {
    "path": "RWKV-v4neo/src/dataset.py",
    "chars": 11788,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/src/model.py",
    "chars": 42739,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/src/model_img.py",
    "chars": 16780,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/src/model_run.py",
    "chars": 8874,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/src/trainer.py",
    "chars": 10878,
    "preview": "import os, math, time, datetime, subprocess\nimport torch\nfrom torch.utils.data import DataLoader\nimport pytorch_lightnin"
  },
  {
    "path": "RWKV-v4neo/src/utils.py",
    "chars": 4587,
    "preview": "import json, time, random, os\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\ntime_slot = {}\ntime_"
  },
  {
    "path": "RWKV-v4neo/train.py",
    "chars": 19396,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v4neo/verify.py",
    "chars": 3525,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v5/compute_magic_prime.py",
    "chars": 1145,
    "preview": "import json, math, random, sys, time, shutil, os, string, re, fileinput\nimport numpy as np\nfrom src.binidx import MMapIn"
  },
  {
    "path": "RWKV-v5/cuda/wkv5_cuda.cu",
    "chars": 5658,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\ntemplate <typename F>\n__global"
  },
  {
    "path": "RWKV-v5/cuda/wkv5_op.cpp",
    "chars": 1469,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\nvoid cuda_forward(int B, int T, int C, i"
  },
  {
    "path": "RWKV-v5/cuda/wkv6_cuda.cu",
    "chars": 8169,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\ntemplate <typename F>\n__global"
  },
  {
    "path": "RWKV-v5/cuda/wkv6_op.cpp",
    "chars": 1269,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\nvoid cuda_forward(int B, int T, int C, i"
  },
  {
    "path": "RWKV-v5/cuda/wkv6state_cuda.cu",
    "chars": 8832,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\ntemplate <typename F>\n__global"
  },
  {
    "path": "RWKV-v5/cuda/wkv6state_op.cpp",
    "chars": 1572,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\ntypedef at::BFloat16 bf16;\n\nvoid cuda_forward(int B, int T, int C, i"
  },
  {
    "path": "RWKV-v5/cuda/wkv7_cuda.cu",
    "chars": 4421,
    "preview": "#include <cuda_bf16.h>\n#include <assert.h>\n\nusing bf = __nv_bfloat16;\n__device__ inline float to_float(const bf & u) { r"
  },
  {
    "path": "RWKV-v5/cuda/wkv7_op.cpp",
    "chars": 1984,
    "preview": "#include <torch/extension.h>\n#include <cuda_bf16.h>\nusing bf = __nv_bfloat16;\n\nvoid cuda_forward(int B, int T, int H, bf"
  },
  {
    "path": "RWKV-v5/demo-training-prepare-v7-pile.sh",
    "chars": 1455,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v5/demo-training-prepare.sh",
    "chars": 2085,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v5/demo-training-run-v7-pile.sh",
    "chars": 2351,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v5/demo-training-run.sh",
    "chars": 2844,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v5/demo.jsonl",
    "chars": 299623,
    "preview": "{\"text\": \"System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\\n\\nUs"
  },
  {
    "path": "RWKV-v5/make_data.py",
    "chars": 5255,
    "preview": "import json, math, random, sys, time, shutil, os, string, re, fileinput\nimport numpy as np\n\n\"\"\"\nHow to use:\n\npython make"
  },
  {
    "path": "RWKV-v5/rwkv_v6_demo.py",
    "chars": 26296,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v5/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "RWKV-v5/src/binidx.py",
    "chars": 8648,
    "preview": "from lib2to3.pgen2 import token\nimport os\nimport torch\nimport numpy as np\nimport shutil\nimport struct\nfrom functools imp"
  },
  {
    "path": "RWKV-v5/src/dataset.py",
    "chars": 9247,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v5/src/model.py",
    "chars": 64765,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v5/src/trainer.py",
    "chars": 11130,
    "preview": "import os, math, time, datetime, subprocess\nimport torch\nfrom torch.utils.data import DataLoader\nimport pytorch_lightnin"
  },
  {
    "path": "RWKV-v5/src/utils.py",
    "chars": 4587,
    "preview": "import json, time, random, os\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\ntime_slot = {}\ntime_"
  },
  {
    "path": "RWKV-v5/tokenizer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "RWKV-v5/tokenizer/rwkv_tokenizer.py",
    "chars": 3201,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v5/train.py",
    "chars": 16370,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v6/README.md",
    "chars": 103,
    "preview": "# Please use /RWKV-v5/ and add --my_testing \"x060\" as an extra train.py parameter, to enable RWKV v6.0\n"
  },
  {
    "path": "RWKV-v7/README.md",
    "chars": 475,
    "preview": "https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7/train_temp# RWKV-7 \"Goose\" x070 (final release, \"rc4a\")\n\n**Train RW"
  },
  {
    "path": "RWKV-v7/cuda/wkv7.cu",
    "chars": 1631,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\n\ntypedef at::Half bf16;\n// typedef at::BFloat16 bf16;\n\ntem"
  },
  {
    "path": "RWKV-v7/cuda/wkv7_op.cpp",
    "chars": 634,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\n\ntypedef at::Half bf16;\n// typedef at::BFloat16 bf16;\n\nvoid cuda_for"
  },
  {
    "path": "RWKV-v7/cuda/wkv7s.cu",
    "chars": 1936,
    "preview": "#include <stdio.h>\n#include <assert.h>\n#include \"ATen/ATen.h\"\n\ntypedef at::Half bf16;\n// typedef at::BFloat16 bf16;\n\ntem"
  },
  {
    "path": "RWKV-v7/cuda/wkv7s_op.cpp",
    "chars": 696,
    "preview": "#include <torch/extension.h>\n#include \"ATen/ATen.h\"\n\ntypedef at::Half bf16;\n// typedef at::BFloat16 bf16;\n\nvoid cuda_for"
  },
  {
    "path": "RWKV-v7/misc/lambada_test.jsonl",
    "chars": 1819752,
    "preview": "{\"text\": \"In my palm is a clear stone, and inside it is a small ivory statuette. A guardian angel.\\n\\n\\\"Figured if you'r"
  },
  {
    "path": "RWKV-v7/mmlu_dev_dataset/dataset_info.json",
    "chars": 2024,
    "preview": "{\n  \"builder_name\": \"parquet\",\n  \"citation\": \"\",\n  \"config_name\": \"all\",\n  \"dataset_name\": \"mmlu\",\n  \"dataset_size\": 168"
  },
  {
    "path": "RWKV-v7/mmlu_dev_dataset/state.json",
    "chars": 248,
    "preview": "{\n  \"_data_files\": [\n    {\n      \"filename\": \"data-00000-of-00001.arrow\"\n    }\n  ],\n  \"_fingerprint\": \"ca7a71e4c243f30b\""
  },
  {
    "path": "RWKV-v7/mmlu_test_dataset/dataset_info.json",
    "chars": 2024,
    "preview": "{\n  \"builder_name\": \"parquet\",\n  \"citation\": \"\",\n  \"config_name\": \"all\",\n  \"dataset_name\": \"mmlu\",\n  \"dataset_size\": 168"
  },
  {
    "path": "RWKV-v7/mmlu_test_dataset/state.json",
    "chars": 249,
    "preview": "{\n  \"_data_files\": [\n    {\n      \"filename\": \"data-00000-of-00001.arrow\"\n    }\n  ],\n  \"_fingerprint\": \"436299c1c09696bb\""
  },
  {
    "path": "RWKV-v7/rwkv_mmlu_eval.py",
    "chars": 4331,
    "preview": "########################################################################################################\r\n# The RWKV Lan"
  },
  {
    "path": "RWKV-v7/rwkv_v7_demo.py",
    "chars": 15698,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v7_demo_fast.py",
    "chars": 18571,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v7_demo_rnn.py",
    "chars": 14370,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v7_numpy.py",
    "chars": 4049,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v7a_demo.py",
    "chars": 19310,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v7b_demo.py",
    "chars": 18956,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v8_rc00_demo.py",
    "chars": 19464,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/rwkv_v8_rc00_hybrid_demo.py",
    "chars": 18340,
    "preview": "########################################################################################################\r\n# The RWKV Lan"
  },
  {
    "path": "RWKV-v7/train_temp/README.md",
    "chars": 6028,
    "preview": "## HOW TO TRAIN RWKV-7 on MiniPile (1.5G tokens) ##\n\n**Simplified RWKV-7 training demo**: https://github.com/BlinkDL/RWK"
  },
  {
    "path": "RWKV-v7/train_temp/cuda/rwkv7_clampw.cpp",
    "chars": 1639,
    "preview": "#include <torch/extension.h>\n\n#ifdef _FP32_\n    using bf = float;\n#else\n    #include <cuda_bf16.h>\n    using bf = __nv_b"
  },
  {
    "path": "RWKV-v7/train_temp/cuda/rwkv7_clampw.cu",
    "chars": 5356,
    "preview": "#include <assert.h>\n\n#ifdef _FP32_\n    using bf = float;\n    #define to_float(u) (u)\n    #define to_bf(u) (u)\n#else\n    "
  },
  {
    "path": "RWKV-v7/train_temp/cuda/wkv7_cuda.cu",
    "chars": 4421,
    "preview": "#include <cuda_bf16.h>\n#include <assert.h>\n\nusing bf = __nv_bfloat16;\n__device__ inline float to_float(const bf & u) { r"
  },
  {
    "path": "RWKV-v7/train_temp/cuda/wkv7_cuda_fp32.cu",
    "chars": 4374,
    "preview": "#include <cuda_bf16.h>\n#include <assert.h>\n\nusing bf = float;\n__device__ inline float to_float(const bf & u) { return u;"
  },
  {
    "path": "RWKV-v7/train_temp/cuda/wkv7_op.cpp",
    "chars": 1984,
    "preview": "#include <torch/extension.h>\n#include <cuda_bf16.h>\nusing bf = __nv_bfloat16;\n\nvoid cuda_forward(int B, int T, int H, bf"
  },
  {
    "path": "RWKV-v7/train_temp/cuda/wkv7_op_fp32.cpp",
    "chars": 1976,
    "preview": "#include <torch/extension.h>\n#include <cuda_bf16.h>\nusing bf = float;\n\nvoid cuda_forward(int B, int T, int H, bf*w, bf*q"
  },
  {
    "path": "RWKV-v7/train_temp/demo-training-prepare-v7-pile.sh",
    "chars": 1408,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v7/train_temp/demo-training-prepare.sh",
    "chars": 1929,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v7/train_temp/demo-training-run-v7-pile.sh",
    "chars": 2170,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v7/train_temp/demo-training-run.sh",
    "chars": 3078,
    "preview": "#!/bin/bash\n############################################################################################################"
  },
  {
    "path": "RWKV-v7/train_temp/rwkv7_train_simplified.py",
    "chars": 15148,
    "preview": "########################################################################################################\r\n# The RWKV Lan"
  },
  {
    "path": "RWKV-v7/train_temp/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "RWKV-v7/train_temp/src/binidx.py",
    "chars": 8605,
    "preview": "import os\nimport torch\nimport numpy as np\nimport struct\nfrom functools import lru_cache\nfrom itertools import accumulate"
  },
  {
    "path": "RWKV-v7/train_temp/src/dataset.py",
    "chars": 2537,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/train_temp/src/model.py",
    "chars": 18659,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v7/train_temp/src/trainer.py",
    "chars": 8607,
    "preview": "import os, math, time, datetime, subprocess\nimport torch\nfrom torch.utils.data import DataLoader\nimport pytorch_lightnin"
  },
  {
    "path": "RWKV-v7/train_temp/train.py",
    "chars": 12090,
    "preview": "########################################################################################################\n# The RWKV Lang"
  },
  {
    "path": "RWKV-v8/251014_rosa_1bit_layer.py",
    "chars": 4303,
    "preview": "import torch\nfrom torch import nn\n\n#####################################################################################"
  },
  {
    "path": "RWKV-v8/251014_rosa_1bit_train.py",
    "chars": 7101,
    "preview": "import torch, random\nfrom torch import nn\nimport torch.nn.functional as F\ntorch.backends.cuda.matmul.allow_tf32=True; to"
  },
  {
    "path": "RWKV-v8/251014_rosa_onlyemb_train.py",
    "chars": 3806,
    "preview": "import torch, random\nfrom torch import nn\nimport torch.nn.functional as F\ntorch.backends.cuda.matmul.allow_tf32=True; to"
  },
  {
    "path": "RWKV-v8/251016_rosa_1bit_run.py",
    "chars": 5117,
    "preview": "import torch, random\r\nfrom torch import nn\r\nimport torch.nn.functional as F\r\ntorch.backends.cuda.matmul.allow_tf32=True;"
  },
  {
    "path": "RWKV-v8/251018_rosa_4bit_run.py",
    "chars": 5322,
    "preview": "import torch, random\r\nfrom torch import nn\r\nimport torch.nn.functional as F\r\ntorch.backends.cuda.matmul.allow_tf32=True;"
  },
  {
    "path": "RWKV-v8/251024_rosaQKV_run.py",
    "chars": 16924,
    "preview": "import random, torch, math\r\nfrom types import SimpleNamespace\r\nimport torch, random\r\nfrom torch import nn\r\nimport torch."
  },
  {
    "path": "RWKV-v8/251105_reverse_run.py",
    "chars": 16954,
    "preview": "import random, torch, math\r\nfrom types import SimpleNamespace\r\nimport torch, random\r\nfrom torch import nn\r\nimport torch."
  },
  {
    "path": "RWKV-v8/260212_rosa1bitLM_L12.py",
    "chars": 13236,
    "preview": "########################################################################################################\r\n# The RWKV Lan"
  },
  {
    "path": "RWKV-v8/260222_rosa4bitLM_L12.py",
    "chars": 13222,
    "preview": "########################################################################################################\r\n# The RWKV Lan"
  },
  {
    "path": "RWKV-v8/README.md",
    "chars": 1627,
    "preview": "# RWKV-8 \"Heron\" with ROSA (Rapid Online Suffix Automaton)\n\n### Community ROSA Projects\n\nhttps://github.com/wjie98/rosa_"
  },
  {
    "path": "RWKV-v8/cuda/wkv7_cuda.cu",
    "chars": 4374,
    "preview": "#include <cuda_bf16.h>\n#include <assert.h>\n\nusing bf = float;\n__device__ inline float to_float(const bf & u) { return u;"
  },
  {
    "path": "RWKV-v8/cuda/wkv7_op.cpp",
    "chars": 1976,
    "preview": "#include <torch/extension.h>\n#include <cuda_bf16.h>\nusing bf = float;\n\nvoid cuda_forward(int B, int T, int H, bf*w, bf*q"
  },
  {
    "path": "Research/rwkv7-g0-7.2b.md",
    "chars": 4590,
    "preview": "I added 2T tokens to RWKV-6 \"World-3\" 7.6B and trained RWKV-7 \"G0\" 7.2B, likely the strongest pure RNN ever existed.\n\nDo"
  }
]

// ... and 3 more files (download for full content)

About this extraction

This page contains the full source code of the BlinkDL/RWKV-LM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 139 files (7.7 MB), approximately 2.0M tokens, and a symbol index with 1025 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!