[
  {
    "path": "KoSBERT/Clustering.py",
    "content": "from sentence_transformers import SentenceTransformer, util\nimport numpy as np\n\nmodel_path = '../Checkpoint/KoSBERT/kosbert-klue-bert-base'\n\nembedder = SentenceTransformer(model_path)\n\n# Corpus with example sentences\ncorpus = ['한 남자가 음식을 먹는다.',\n          '한 남자가 빵 한 조각을 먹는다.',\n          '그 여자가 아이를 돌본다.',\n          '한 남자가 말을 탄다.',\n          '한 여자가 바이올린을 연주한다.',\n          '두 남자가 수레를 숲 솦으로 밀었다.',\n          '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n          '원숭이 한 마리가 드럼을 연주한다.',\n          '치타 한 마리가 먹이 뒤에서 달리고 있다.',\n          '한 남자가 파스타를 먹는다.',\n          '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n          '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\ncorpus_embeddings = embedder.encode(corpus)\n\n# Then, we perform k-means clustering using sklearn:\nfrom sklearn.cluster import KMeans\n\nnum_clusters = 5\nclustering_model = KMeans(n_clusters=num_clusters)\nclustering_model.fit(corpus_embeddings)\ncluster_assignment = clustering_model.labels_\n\nclustered_sentences = [[] for i in range(num_clusters)]\nfor sentence_id, cluster_id in enumerate(cluster_assignment):\n    clustered_sentences[cluster_id].append(corpus[sentence_id])\n\nfor i, cluster in enumerate(clustered_sentences):\n    print(\"Cluster \", i+1)\n    print(cluster)\n    print(\"\")\n"
  },
  {
    "path": "KoSBERT/README.md",
    "content": "# KoSentenceBERT\n[[Github]](https://github.com/UKPLab/sentence-transformers) Official implementation of SBERT. <br>\nKorean SentenceBERT : Korean Sentence Embeddings using Siamese BERT-Networks.\n\n## Quick start\n- If you want to do inference quickly, download the pre-trained models and then you can start some downstream tasks.\n```\nbash get_model_checkpoint.sh\npython SemanticSearch.py\n```\n\n## Training\n- Before training or evaluation, please download the datasets by running\n    ```\n    bash get_model_dataset.sh\n    ```\n- Two stage training\n    - First step, training NLI dataset \n    \n        ```\n        python training_nli.py --model klue/bert-base --batch 32 --evaluation_steps 1000 --epochs 1\n        ```\n    - Second step, continued training STS dataset \n    \n        ```\n        python con_training_sts.py --model klue/bert-base --batch 32 --evaluation_steps 1000 --epochs 4\n        ```\n    \n- Run Examples\n  ```\n  bash run_example.sh\n  ```\n### Hyperparameters\n- Training NLI\n  1. Pooling Method: MEAN strategy\n  2. Batch Size: 32\n  3. Evaluation Steps: 1000\n  4. Epochs: 1(BERT), 2(RoBERTa)\n  \n- Continued Training STS\n  1. Pooling Method: MEAN strategy\n  2. Batch Size: 32\n  3. Evaluation Steps: 1000\n  4. 
Epochs: 4\n\n### Semantic Search\n```\npython SemanticSearch.py\n```\n```python\nfrom sentence_transformers import SentenceTransformer, util\nimport numpy as np\n\nmodel_path = '../Checkpoint/KoSBERT/kosbert-klue-bert-base'\n\nembedder = SentenceTransformer(model_path)\n\n# Corpus with example sentences\ncorpus = ['한 남자가 음식을 먹는다.',\n          '한 남자가 빵 한 조각을 먹는다.',\n          '그 여자가 아이를 돌본다.',\n          '한 남자가 말을 탄다.',\n          '한 여자가 바이올린을 연주한다.',\n          '두 남자가 수레를 숲 솦으로 밀었다.',\n          '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n          '원숭이 한 마리가 드럼을 연주한다.',\n          '치타 한 마리가 먹이 뒤에서 달리고 있다.']\n\ncorpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)\n\n# Query sentences:\nqueries = ['한 남자가 파스타를 먹는다.',\n           '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n           '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n# Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity\ntop_k = 5\nfor query in queries:\n    query_embedding = embedder.encode(query, convert_to_tensor=True)\n    cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]\n    cos_scores = cos_scores.cpu()\n\n    #We use np.argpartition, to only partially sort the top_k results\n    top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]\n\n    print(\"\\n\\n======================\\n\\n\")\n    print(\"Query:\", query)\n    print(\"\\nTop 5 most similar sentences in corpus:\")\n\n    for idx in top_results[0:top_k]:\n        print(corpus[idx].strip(), \"(Score: %.4f)\" % (cos_scores[idx]))\n\n```\n\n- Results are as follows :\n\n```\n\nQuery: 한 남자가 파스타를 먹는다.\n\nTop 5 most similar sentences in corpus:\n한 남자가 음식을 먹는다. (Score: 0.6141)\n한 남자가 빵 한 조각을 먹는다. (Score: 0.5952)\n한 남자가 말을 탄다. (Score: 0.1231)\n한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.0752)\n두 남자가 수레를 숲 솦으로 밀었다. (Score: 0.0486)\n\n\n======================\n\n\nQuery: 고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.\n\nTop 5 most similar sentences in corpus:\n원숭이 한 마리가 드럼을 연주한다. (Score: 0.6656)\n치타 한 마리가 먹이 뒤에서 달리고 있다. 
(Score: 0.2988)\n한 여자가 바이올린을 연주한다. (Score: 0.1566)\n한 남자가 말을 탄다. (Score: 0.1112)\n한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.0262)\n\n\n======================\n\n\nQuery: 치타가 들판을 가로 질러 먹이를 쫓는다.\n\nTop 5 most similar sentences in corpus:\n치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.7570)\n두 남자가 수레를 숲 솦으로 밀었다. (Score: 0.3658)\n원숭이 한 마리가 드럼을 연주한다. (Score: 0.3583)\n한 남자가 말을 탄다. (Score: 0.0505)\n그 여자가 아이를 돌본다. (Score: -0.0087)\n```\n\n### Clustering \n```\npython Clustering.py\n```\n```python\nfrom sentence_transformers import SentenceTransformer, util\nimport numpy as np\n\nmodel_path = '../Checkpoint/KoSBERT/kosbert-klue-bert-base'\n\nembedder = SentenceTransformer(model_path)\n\n# Corpus with example sentences\ncorpus = ['한 남자가 음식을 먹는다.',\n          '한 남자가 빵 한 조각을 먹는다.',\n          '그 여자가 아이를 돌본다.',\n          '한 남자가 말을 탄다.',\n          '한 여자가 바이올린을 연주한다.',\n          '두 남자가 수레를 숲 솦으로 밀었다.',\n          '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n          '원숭이 한 마리가 드럼을 연주한다.',\n          '치타 한 마리가 먹이 뒤에서 달리고 있다.',\n          '한 남자가 파스타를 먹는다.',\n          '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n          '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\ncorpus_embeddings = embedder.encode(corpus)\n\n# Then, we perform k-means clustering using sklearn:\nfrom sklearn.cluster import KMeans\n\nnum_clusters = 5\nclustering_model = KMeans(n_clusters=num_clusters)\nclustering_model.fit(corpus_embeddings)\ncluster_assignment = clustering_model.labels_\n\nclustered_sentences = [[] for i in range(num_clusters)]\nfor sentence_id, cluster_id in enumerate(cluster_assignment):\n    clustered_sentences[cluster_id].append(corpus[sentence_id])\n\nfor i, cluster in enumerate(clustered_sentences):\n    print(\"Cluster \", i+1)\n    print(cluster)\n    print(\"\")\n```\n- Results are as follows:\n```\nCluster  1\n['한 남자가 음식을 먹는다.', '한 남자가 빵 한 조각을 먹는다.', '한 남자가 파스타를 먹는다.']\n\nCluster  2\n['원숭이 한 마리가 드럼을 연주한다.', '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.']\n\nCluster  3\n['한 남자가 말을 탄다.', '두 남자가 수레를 숲 솦으로 밀었다.', '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.']\n\nCluster  
4\n['치타 한 마리가 먹이 뒤에서 달리고 있다.', '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\nCluster  5\n['그 여자가 아이를 돌본다.', '한 여자가 바이올린을 연주한다.']\n```\n"
  },
  {
    "path": "KoSBERT/SemanticSearch.py",
    "content": "from sentence_transformers import SentenceTransformer, util\nimport numpy as np\n\nmodel_path = '../Checkpoint/KoSBERT/kosbert-klue-bert-base'\n\nembedder = SentenceTransformer(model_path)\n\n# Corpus with example sentences\ncorpus = ['한 남자가 음식을 먹는다.',\n          '한 남자가 빵 한 조각을 먹는다.',\n          '그 여자가 아이를 돌본다.',\n          '한 남자가 말을 탄다.',\n          '한 여자가 바이올린을 연주한다.',\n          '두 남자가 수레를 숲 솦으로 밀었다.',\n          '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n          '원숭이 한 마리가 드럼을 연주한다.',\n          '치타 한 마리가 먹이 뒤에서 달리고 있다.']\n\ncorpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)\n\n# Query sentences:\nqueries = ['한 남자가 파스타를 먹는다.',\n           '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n           '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n# Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity\ntop_k = 5\nfor query in queries:\n    query_embedding = embedder.encode(query, convert_to_tensor=True)\n    cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]\n    cos_scores = cos_scores.cpu()\n\n    #We use np.argpartition, to only partially sort the top_k results\n    top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]\n\n    print(\"\\n\\n======================\\n\\n\")\n    print(\"Query:\", query)\n    print(\"\\nTop 5 most similar sentences in corpus:\")\n\n    for idx in top_results[0:top_k]:\n        print(corpus[idx].strip(), \"(Score: %.4f)\" % (cos_scores[idx]))\n\n\n"
  },
  {
    "path": "KoSBERT/con_training_sts.py",
    "content": "from torch.utils.data import DataLoader\nimport math\nfrom sentence_transformers import SentenceTransformer,  SentencesDataset, LoggingHandler, losses, util, InputExample\nfrom sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\nimport logging\nfrom datetime import datetime\nimport os\nimport gzip\nimport csv\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model', type=str, default='klue/bert-base')\nparser.add_argument('--batch', type=int, default=32)\nparser.add_argument('--evaluation_steps', type=int, default=1000)\nparser.add_argument('--epochs', type=int, default=4)\nargs = parser.parse_args()\n\nlogging.basicConfig(format='%(asctime)s - %(message)s',\n                    datefmt='%Y-%m-%d %H:%M:%S',\n                    level=logging.INFO,\n                    handlers=[LoggingHandler()])\n\nmodel_name = './output/training_nli_'+args.model.replace(\"/\", \"-\")\n\ntrain_batch_size = args.batch\nnum_epochs = args.epochs\n\nmodel_save_path = 'output/kosbert-'+args.model.replace(\"/\", \"-\")\n\nmodel = SentenceTransformer(model_name)\n\nlogging.info(\"Read STSbenchmark train dataset\")\n\ntrain_samples = []\ndev_samples = []\ntest_samples = []\nwith open('../Dataset/tune_sts_dev.tsv', 'rt', encoding='utf-8') as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, score = line.split('\\t')\n        score = score.strip()\n        score = float(score) / 5.0\n        dev_samples.append(InputExample(texts= [s1,s2], label=score))\n\nwith open('../Dataset/tune_sts_test.tsv', 'rt', encoding='utf-8') as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, score = line.split('\\t')\n        score = score.strip()\n        score = float(score) / 5.0\n        test_samples.append(InputExample(texts= [s1,s2], label=score))\n\nwith open('../Dataset/tune_sts_train.tsv', 'rt', encoding='utf-8') as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, 
score = line.split('\\t')\n        score = score.strip()\n        score = float(score) / 5.0\n        train_samples.append(InputExample(texts= [s1,s2], label=score))\n\ntrain_dataset = SentencesDataset(train_samples, model)\ntrain_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\ntrain_loss = losses.CosineSimilarityLoss(model=model)\n\n\n# Development set: Measure correlation between cosine score and gold labels\nlogging.info(\"Read STSbenchmark dev dataset\")\nevaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')\n\nwarmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1) #10% of train data for warm-up\nlogging.info(\"Warmup-steps: {}\".format(warmup_steps))\n\n\n# Train the model\nmodel.fit(train_objectives=[(train_dataloader, train_loss)],\n          evaluator=evaluator,\n          epochs=num_epochs,\n          evaluation_steps=args.evaluation_steps,\n          warmup_steps=warmup_steps,\n          output_path=model_save_path)\n\n\n##############################################################################\n#\n# Load the stored model and evaluate its performance on STS benchmark dataset\n#\n##############################################################################\n\nmodel = SentenceTransformer(model_save_path)\ntest_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')\ntest_evaluator(model, output_path=model_save_path)\n"
  },
  {
    "path": "KoSBERT/output/empty.txt",
    "content": "."
  },
  {
    "path": "KoSBERT/run_example.sh",
    "content": "#!/bin/bash\n\n# bert-base\necho \"First Step Training NLI Dataset (BERT-BASE)\"\nCUDA_VISIBLE_DEVICES=0 python training_nli.py --model klue/bert-base --batch 32 --evaluation_steps 1000 --epochs 1\necho \"Second Step Continuously Training STS Dataset (BERT-BASE)\"\nCUDA_VISIBLE_DEVICES=0 python con_training_sts.py --model klue/bert-base --batch 32 --evaluation_steps 1000 --epochs 4\n\n# roberta-base\necho \"First Step Training NLI Dataset (ROBERTA-BASE)\"\nCUDA_VISIBLE_DEVICES=0 python training_nli.py --model klue/roberta-base --batch 32 --evaluation_steps 1000 --epochs 1\necho \"Second Step Continuously Training STS Dataset (ROBERTA-BASE)\"\nCUDA_VISIBLE_DEVICES=0 python con_training_sts.py --model klue/roberta-base --batch 32 --evaluation_steps 1000 --epochs 4\n\n# roberta-large\necho \"First Step Training NLI Dataset (ROBERAT-LARGE)\"\nCUDA_VISIBLE_DEVICES=0 python training_nli.py --model klue/roberta-large --batch 32 --evaluation_steps 1000 --epochs 1\necho \"Second Step Continuously Training STS Dataset (ROBERTA-LARGE)\"\nCUDA_VISIBLE_DEVICES=0 python con_training_sts.py --model klue/roberta-large --batch 32 --evaluation_steps 1000 --epochs 4\n\n"
  },
  {
    "path": "KoSBERT/training_nli.py",
    "content": "from torch.utils.data import DataLoader\nimport math\nfrom sentence_transformers import models, losses\nfrom sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, util, InputExample\nfrom sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\nimport logging\nfrom datetime import datetime\nimport sys\nimport os\nimport gzip\nimport csv\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model', type=str, default='klue/bert-base')\nparser.add_argument('--batch', type=int, default=32)\nparser.add_argument('--evaluation_steps', type=int, default=1000)\nparser.add_argument('--epochs', type=int, default=1)\n\nargs = parser.parse_args()\n\nmodel_name = args.model\n\ntrain_batch_size = args.batch\n\nmodel_save_path = 'output/training_nli_'+model_name.replace(\"/\", \"-\")#+'-'+datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n\nword_embedding_model = models.Transformer(model_name)\n\n# Apply mean pooling to get one fixed sized sentence vector\npooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),\n                               pooling_mode_mean_tokens=True,\n                               pooling_mode_cls_token=False,\n                               pooling_mode_max_tokens=False)\n\nmodel = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n\nlogging.info(\"Read AllNLI train dataset\")\n\nlabel2int = {\"contradiction\": 0, \"entailment\": 1, \"neutral\": 2}\ntrain_samples = []\n\nwith open('../Dataset/snli_1.0_train.ko.tsv', \"rt\", encoding=\"utf-8\") as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, label = line.split('\\t')\n        label = label2int[label.strip()]\n        train_samples.append(InputExample(texts=[s1, s2], label=label))\n\ntrain_dataset = SentencesDataset(train_samples, model=model)\ntrain_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\ntrain_loss = 
losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=len(label2int))\n\n\n#Read STSbenchmark dataset and use it as development set\nlogging.info(\"Read STSbenchmark dev dataset\")\ndev_samples = []\n\nwith open('../Dataset/tune_sts_dev.tsv', 'rt', encoding='utf-8') as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, score = line.split('\\t')\n        score = score.strip()\n        score = float(score) / 5.0\n        dev_samples.append(InputExample(texts= [s1,s2], label=score))\n\ndev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-dev')\n\nnum_epochs = args.epochs\n\nwarmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1) #10% of train data for warm-up\nlogging.info(\"Warmup-steps: {}\".format(warmup_steps))\n\n# Train the model\nmodel.fit(train_objectives=[(train_dataloader, train_loss)],\n          evaluator=dev_evaluator,\n          epochs=num_epochs,\n          evaluation_steps=args.evaluation_steps,\n          warmup_steps=warmup_steps,\n          output_path=model_save_path\n          )\n\n\n\n##############################################################################\n#\n# Load the stored model and evaluate its performance on STS benchmark dataset\n#\n##############################################################################\n\ntest_samples = []\nwith open('../Dataset/tune_sts_test.tsv', 'rt', encoding='utf-8') as fIn:\n    lines = fIn.readlines()\n    for line in lines:\n        s1, s2, score = line.split('\\t')\n        score = score.strip()\n        score = float(score) / 5.0\n        test_samples.append(InputExample(texts=[s1,s2], label=score))\n\nprint(\"\\n\\n\\n\")\nprint(\"======================TEST===================\")\nprint(\"\\n\\n\\n\")\nmodel = SentenceTransformer(model_save_path)\nprint(f\"model save path > {model_save_path}\")\ntest_evaluator = 
EmbeddingSimilarityEvaluator.from_input_examples(test_samples, batch_size=train_batch_size, name='sts-test')\ntest_evaluator(model, output_path=model_save_path)\n"
  },
  {
    "path": "KoSentenceT5/README.md",
    "content": "# KoSentenceT5\nKoSentenceT5 : Korean Sentence Embeddings using T5. <br>\n> **Warning** <br>\n> This repository uses ETRI-T5 model and does not provide it. You can download T5 model from [here](https://aiopen.etri.re.kr/service_dataset.php).\n\n## Training \n- Before training or evaluation, please download the datasets by running\n```\nbash get_model_dataset.sh\n```\n### Train KoSentenceT5\n  ```\n  python main.py \\\n    --model etri-t5 \\\n    --multi_gpu True \\\n    --test False \\\n    --max_len 110 \\\n    --batch_size 64 \\\n    --epochs 2 \\\n    --eval_steps 125 \\\n    --lr 0.0001 \\\n    --warmup_ratio 0.01 \\\n    --temperature 0.05 \\\n    --path_to_data ../Dataset/ \\\n    --train_data train_nli.tsv \\\n    --valid_data valid_sts.tsv\n  ```\n### Evaluation\n  ```\n  python main.py \\\n    --model etri-t5 \\\n    --train False \\\n    --test True \\\n    --max_len 110 \\\n    --batch_size 64 \\\n    --temperature 0.05 \\\n    --path_to_data ../Dataset/ \\\n    --test_data test_sts.tsv \\\n  ```\n\n### Run Examples\n```\nbash run_example.sh\n```\n"
  },
  {
    "path": "KoSentenceT5/apex/RNN/README.md",
    "content": "Under construction...\n"
  },
  {
    "path": "KoSentenceT5/apex/RNN/RNNBackend.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\nimport torch.nn.functional as F\n\nimport math\n\n\ndef is_iterable(maybe_iterable):\n    return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)\n\n\ndef flatten_list(tens_list):\n    \"\"\"\n    flatten_list\n    \"\"\"\n    if not is_iterable(tens_list):\n        return tens_list\n    \n    return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )\n\n    \n#These modules always assumes batch_first\nclass bidirectionalRNN(nn.Module):\n    \"\"\"\n    bidirectionalRNN\n    \"\"\"\n    def __init__(self, inputRNN, num_layers=1, dropout = 0):\n        super(bidirectionalRNN, self).__init__()\n        self.dropout = dropout\n        self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)\n        self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)\n        self.rnns = nn.ModuleList([self.fwd, self.bckwrd])\n        \n    #collect hidden option will return all hidden/cell states from entire RNN\n    def forward(self, input, collect_hidden=False):\n        \"\"\"\n        forward()\n        \"\"\"\n        seq_len = input.size(0)\n        bsz = input.size(1)\n\n        fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))\n        bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))\n        \n        output = torch.cat( [fwd_out, bckwrd_out], -1 )\n        hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )\n\n        return output, hiddens\n\n    def reset_parameters(self):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_parameters()\n        \n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_hidden(bsz)\n\n    def 
detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.detach_hidden()\n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_hidden(bsz)\n\n    def init_inference(self, bsz):    \n        \"\"\"\n        init_inference()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_inference(bsz)\n\n   \n#assumes hidden_state[0] of inputRNN is output hidden state\n#constructor either takes an RNNCell or list of RNN layers\nclass stackedRNN(nn.Module):        \n    \"\"\"\n    stackedRNN\n    \"\"\"\n    def __init__(self, inputRNN, num_layers=1, dropout=0):\n        super(stackedRNN, self).__init__()\n        \n        self.dropout = dropout\n        \n        if isinstance(inputRNN, RNNCell):\n            self.rnns = [inputRNN]\n            for i in range(num_layers-1):\n                self.rnns.append(inputRNN.new_like(inputRNN.output_size))\n        elif isinstance(inputRNN, list):\n            assert len(inputRNN) == num_layers, \"RNN list length must be equal to num_layers\"\n            self.rnns=inputRNN\n        else:\n            raise RuntimeError()\n        \n        self.nLayers = len(self.rnns)\n        \n        self.rnns = nn.ModuleList(self.rnns)\n\n\n    '''\n    Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])\n    If collect hidden will also return Tuple(\n        [n_hidden_states][sequence steps] Tensor([layer][batch size][features])\n    )\n    If not collect hidden will also return Tuple(\n        [n_hidden_states] Tensor([layer][batch size][features])\n    '''\n    def forward(self, input, collect_hidden=False, reverse=False):\n        \"\"\"\n        forward()\n        \"\"\"\n        seq_len = input.size(0)\n        bsz = input.size(1)\n        inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)\n\n        hidden_states 
= [[] for i in range(self.nLayers)]\n        outputs = []\n\n        for seq in inp_iter:\n            for layer in range(self.nLayers):\n\n                if layer == 0:\n                    prev_out = input[seq]\n                    \n                outs = self.rnns[layer](prev_out)\n\n                if collect_hidden:\n                    hidden_states[layer].append(outs)\n                elif seq == seq_len-1:\n                    hidden_states[layer].append(outs)\n                    \n                prev_out = outs[0]\n\n            outputs.append(prev_out)\n\n        if reverse:\n            outputs = list(reversed(outputs))\n        '''\n        At this point outputs is in format:\n        list( [seq_length] x Tensor([bsz][features]) )\n        need to convert it to:\n        list( Tensor([seq_length][bsz][features]) )\n        '''\n        output = flatten_list(outputs)\n\n        '''\n        hidden_states at this point is in format:\n        list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )\n        need to convert it to:\n          For not collect hidden:\n            list( [hidden_states] x Tensor([layer][bsz][features]) )\n          For collect hidden:\n            list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )\n        '''\n        if not collect_hidden:\n            seq_len = 1\n        n_hid = self.rnns[0].n_hidden_states\n        new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]\n\n\n        for i in range(n_hid):\n            for j in range(seq_len):\n                for k in range(self.nLayers):\n                    new_hidden[i][j][k] = hidden_states[k][j][i]\n\n        hidden_states = new_hidden\n        #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )\n        #Reverse seq_length if reverse\n        if reverse:\n            hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)\n\n        
#flatten layer dimension into tensor\n        hiddens = list( list(\n            flatten_list(seq) for seq in hidden )\n                        for hidden in hidden_states )\n        \n        #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )\n        #Remove seq_length dimension if not collect_hidden, and return the\n        #flattened hiddens (previously the unflattened hidden_states were\n        #returned and the flattened tensors were discarded)\n        if not collect_hidden:\n            hiddens = list( entry[0] for entry in hiddens )\n        return output, hiddens\n    \n    def reset_parameters(self):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_parameters()\n        \n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_hidden(bsz)\n\n    def detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.detach_hidden()\n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_hidden(bsz)\n\n    def init_inference(self, bsz):    \n        \"\"\" \n        init_inference()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_inference(bsz)\n\nclass RNNCell(nn.Module):\n    \"\"\" \n    RNNCell \n    gate_multiplier is related to the architecture you're working with\n    For LSTM-like it will be 4 and GRU-like will be 3.\n    Always assumes input is NOT batch_first.\n    Output size that's not hidden size will use output projection\n    Hidden_states is number of hidden states that are needed for cell\n    if one will go directly to cell as tensor, if more will go as list\n    \"\"\"\n    def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):\n        super(RNNCell, self).__init__()\n\n        self.gate_multiplier = gate_multiplier\n        self.input_size = 
input_size\n        self.hidden_size = hidden_size\n        self.cell = cell\n        self.bias = bias\n        self.output_size = output_size\n        if output_size is None:\n            self.output_size = hidden_size\n\n        self.gate_size = gate_multiplier * self.hidden_size\n        self.n_hidden_states = n_hidden_states\n\n        self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))\n        self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))\n\n        #Check if there's recurrent projection\n        if(self.output_size != self.hidden_size):\n            self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))\n\n        self.b_ih = self.b_hh = None\n        if self.bias:\n            self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))\n            self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))\n            \n        #hidden states for forward\n        self.hidden = [ None for states in range(self.n_hidden_states)]\n\n        self.reset_parameters()\n\n    def new_like(self, new_input_size=None):\n        \"\"\"\n        new_like()\n        \"\"\"\n        if new_input_size is None:\n            new_input_size = self.input_size\n            \n        return type(self)(self.gate_multiplier,\n                       new_input_size,\n                       self.hidden_size,\n                       self.cell,\n                       self.n_hidden_states,\n                       self.bias,\n                       self.output_size)\n\n    \n    #Use xavier where we can (weights), otherwise use uniform (bias)\n    def reset_parameters(self, gain=1):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        stdev = 1.0 / math.sqrt(self.hidden_size)\n        for param in self.parameters():\n            param.data.uniform_(-stdev, stdev)\n    '''\n    Xavier reset:\n    def reset_parameters(self, gain=1):\n        stdv = 1.0 / math.sqrt(self.gate_size)\n\n        for param in 
self.parameters():\n            if (param.dim() > 1):\n                torch.nn.init.xavier_normal(param, gain)\n            else:\n                param.data.uniform_(-stdv, stdv)\n    '''\n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for param in self.parameters():\n            if param is not None:\n                a_param = param\n                break\n\n        for i, _ in enumerate(self.hidden):\n            if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):\n\n                if i==0:\n                    hidden_size = self.output_size\n                else:\n                    hidden_size = self.hidden_size\n\n                tens = a_param.data.new(bsz, hidden_size).zero_()\n                self.hidden[i] = Variable(tens, requires_grad=False)\n            \n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for i, _ in enumerate(self.hidden):\n            self.hidden[i] = None\n        self.init_hidden(bsz)\n\n    def detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for i, _ in enumerate(self.hidden):\n            if self.hidden[i] is None:\n                raise RuntimeError(\"Must initialize hidden state before you can detach it\")\n        for i, _ in enumerate(self.hidden):\n            self.hidden[i] = self.hidden[i].detach()\n        \n    def forward(self, input):\n        \"\"\"\n        forward()\n        if not inited or bsz has changed this will create hidden states\n        \"\"\"\n        self.init_hidden(input.size()[0])\n\n        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden\n        self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)\n        if(self.n_hidden_states > 1):\n            self.hidden = list(self.hidden)\n        else:\n            self.hidden=[self.hidden]\n\n        if self.output_size != 
self.hidden_size:\n            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)\n\n        return tuple(self.hidden)\n"
  },
  {
    "path": "KoSentenceT5/apex/RNN/__init__.py",
    "content": "from .models import LSTM, GRU, ReLU, Tanh, mLSTM\n\n__all__ = ['models']\n"
  },
  {
    "path": "KoSentenceT5/apex/RNN/cells.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .RNNBackend import RNNCell\n\nfrom torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend\n\nimport math \n\n\nclass mLSTMRNNCell(RNNCell):\n    \"\"\"\n    mLSTMRNNCell\n    \"\"\"\n\n    def __init__(self, input_size, hidden_size, bias = False, output_size = None):\n        gate_multiplier = 4\n        super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)\n\n        self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))\n        self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))\n\n        self.reset_parameters()\n\n    def forward(self, input):\n        \"\"\"\n        mLSTMRNNCell.forward()\n        \"\"\"\n        #if not inited or bsz has changed this will create hidden states\n        self.init_hidden(input.size()[0])\n\n        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden\n\n        self.hidden = list(\n                           self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,\n                           b_ih=self.b_ih, b_hh=self.b_hh)\n        )\n        \n        if self.output_size != self.hidden_size:\n            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)\n        return tuple(self.hidden)\n\n\n    def new_like(self, new_input_size=None):\n        if new_input_size is None:\n            new_input_size = self.input_size\n        \n        return type(self)(\n            new_input_size,\n            self.hidden_size,\n            self.bias,\n            self.output_size)\n\ndef mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):\n    \"\"\"\n    mLSTMCell\n    \"\"\"\n\n    if input.is_cuda:\n        igates = F.linear(input, w_ih)\n        m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)\n        hgates = F.linear(m, 
w_hh)\n\n        state = fusedBackend.LSTMFused.apply\n        return state(igates, hgates, hidden[1], b_ih, b_hh)\n\n    hx, cx = hidden\n\n    m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)\n    gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)\n\n    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)\n\n    # torch.sigmoid/torch.tanh replace the deprecated F.sigmoid/F.tanh; behavior is identical.\n    ingate = torch.sigmoid(ingate)\n    forgetgate = torch.sigmoid(forgetgate)\n    cellgate = torch.tanh(cellgate)\n    outgate = torch.sigmoid(outgate)\n\n    cy = (forgetgate * cx) + (ingate * cellgate)\n    hy = outgate * torch.tanh(cy)\n\n    return hy, cy\n"
  },
  {
    "path": "KoSentenceT5/apex/RNN/models.py",
    "content": "import torch\n\nfrom torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell\n\nfrom .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell\nfrom .cells import mLSTMRNNCell, mLSTMCell\n\ndef toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):\n    \"\"\"\n    :class:`toRNNBackend`\n    \"\"\"\n\n    if bidirectional:\n        return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)\n    else:\n        return stackedRNN(inputRNN, num_layers, dropout = dropout)\n\n\ndef LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`LSTM`\n    \"\"\"\n    inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`GRU`\n    \"\"\"\n    inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`ReLU`\n    \"\"\"\n    inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`Tanh`\n    \"\"\"\n    inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n        \ndef mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, 
bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`mLSTM`\n    \"\"\"\n    inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/__init__.py",
    "content": "# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten\nimport torch\nimport warnings\n\nif torch.distributed.is_available():\n    from . import parallel\n\nfrom . import amp\nfrom . import fp16_utils\n\n# For optimizers and normalization there is no Python fallback.\n# Absence of cuda backend is a hard error.\n# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda\n# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext\n# so they expect those backends to be available, but for some reason they actually aren't\n# available (for example because they built improperly in a way that isn't revealed until\n# load time) the error message is timely and visible.\nfrom . import optimizers\nfrom . import normalization\nfrom . import pyprof\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/README.md",
    "content": "# amp: Automatic Mixed Precision\n\n## Annotating User Functions\n\nNearly all PyTorch user code needs nothing more than the two steps\nabove to use amp. After all, custom layers are built out of simpler\nPyTorch components, and amp already can see those.\n\nHowever, any custom C++ or CUDA code is outside of amp's (default)\nview of things. For example, suppose I implemented a new recurrent\ncell called a \"forgetful recurrent unit\" that calls directly into a\nCUDA backend:\n\n```python\nfrom backend import FRUBackend\n\ndef fru(input, hidden, weight, bias):\n    # call to CUDA code\n    FRUBackend(input, hidden, weight, bias)\n```\n\nIn this case, it is possible to get a runtime type mismatch. For\nexample, you might have `input` in fp16, and `weight` in fp32, and amp\ndoesn't have the visibility to insert an appropriate cast.\n\namp exposes two ways to handle \"invisible\" backend code: function\nannotations and explicit registration.\n\n#### Function annotation\n\nThe first way to handle backend code is a set of function annotations:\n\n- `@amp.half_function`\n- `@amp.float_function`\n- `@amp.promote_function`\n\nThese correspond to:\n\n- Cast all arguments to fp16\n- Cast all argumnets fo fp32\n- If there are any type mismatches, cast everything to the widest type\n\nIn our example, we believe that the FRU unit is fp16-safe and will get\nperformance gains from casting its arguments to fp16, so we write:\n\n```python\n@amp.half_function\ndef fru(input, hidden, weight, bias):\n    #...\n```\n\n#### Explicit registration\n\nThe other way to handle backend code is with explicit function\nregistration:\n\n- `amp.register_half_function(module, function_name)`\n- `amp.register_float_function(module, function_name)`\n- `amp.register_promote_function(module, function_name)`\n\nWhen using this API, `module` is the containing class or module for\nthe function, and `function_name` is the _string_ name of the\nfunction. 
Note that the function must be registered before the call to\n`amp.initalize()`.\n\nFor our FRU unit, we can register the backend function directly:\n\n```python\nimport backend\n\namp.register_half_function(backend, 'FRUBackend')\n```\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/__init__.py",
    "content": "from .amp import init, half_function, float_function, promote_function,\\\n    register_half_function, register_float_function, register_promote_function\nfrom .handle import scale_loss, disable_casts\nfrom .frontend import initialize, state_dict, load_state_dict\nfrom ._amp_state import master_params, _amp_state\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/__version__.py",
    "content": "VERSION = (0, 1, 0)\n__version__ = '.'.join(map(str, VERSION))\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/_amp_state.py",
    "content": "# This is a \"header object\" that allows different amp modules to communicate.\n# I'm a C++ guy, not a python guy.  I decided this approach because it seemed most C++-like.\n# But apparently it's ok:\n# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm\nimport os\nimport torch\n\nTORCH_MAJOR = int(torch.__version__.split('.')[0])\nTORCH_MINOR = int(torch.__version__.split('.')[1])\n\n\nif TORCH_MAJOR == 1 and TORCH_MINOR < 8:\n    from torch._six import container_abcs\nelse:\n    import collections.abc as container_abcs\n\n\nclass AmpState(object):\n    def __init__(self):\n        self.hard_override=False\n        self.allow_incoming_model_not_fp32 = False\n        self.verbosity=1\n\n\n# Attribute stash.  Could also just stash things as global module attributes.\n_amp_state = AmpState()\n\n\ndef warn_or_err(msg):\n    if _amp_state.hard_override:\n        print(\"Warning:  \" + msg)\n    else:\n        raise RuntimeError(msg)\n        # I'm not sure if allowing hard_override is a good idea.\n        # + \"  If you're sure you know what you're doing, supply \" +\n        #                    \"hard_override=True to amp.initialize.\")\n\n\ndef maybe_print(msg, rank0=False):\n    distributed = torch.distributed.is_available() and \\\n        torch.distributed.is_initialized() and \\\n        torch.distributed.get_world_size() > 1\n    if _amp_state.verbosity > 0:\n        if rank0:\n            if distributed:\n                if torch.distributed.get_rank() == 0:\n                    print(msg)\n            else:\n                print(msg)\n        else:\n            print(msg)\n\n\n# def iter_params(param_groups):\n#     for group in param_groups:\n#         for p in group['params']:\n#             yield p\n\n\ndef master_params(optimizer):\n    \"\"\"\n    Generator expression that iterates over the params owned by ``optimizer``.\n\n    Args:\n        optimizer: An optimizer previously returned from ``amp.initialize``.\n 
   \"\"\"\n    for group in optimizer.param_groups:\n        for p in group['params']:\n            yield p\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/_initialize.py",
    "content": "import torch\nfrom torch._six import string_classes\nimport functools\nimport numpy as np\nimport sys\nfrom types import MethodType\nimport warnings\nfrom ._amp_state import _amp_state, warn_or_err, container_abcs\nfrom .handle import disable_casts\nfrom .scaler import LossScaler\nfrom ._process_optimizer import _process_optimizer\nfrom apex.fp16_utils import convert_network\nfrom ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general\nfrom ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused\n\nif torch.distributed.is_available():\n    from ..parallel import DistributedDataParallel as apex_DDP\n    from ..parallel.LARC import LARC\n\n\ndef to_type(dtype, t):\n    if isinstance(t, torch.Tensor):\n        if not t.is_cuda:\n            # This should not be a hard error, since it may be legitimate.\n            warnings.warn(\"An input tensor was not cuda.\")\n        # GANs require this.\n        # if t.requires_grad:\n        #     warn_or_err(\"input data requires grad.  Since input data is not a model parameter,\\n\"\n        #         \"its gradients will not be properly allreduced by DDP.\")\n        if t.is_floating_point():\n            return t.to(dtype)\n        return t\n    else:\n        # Trust the user's custom batch type, that's all I can do here.\n        return t.to(dtype)\n\n\n# Modified from torch.optim.optimizer.py.  
This is a bit more general than casted_args in utils.py.\ndef applier(value, fn):\n    if isinstance(value, torch.Tensor):\n        return fn(value)\n    elif isinstance(value, string_classes):\n        return value\n    elif isinstance(value, np.ndarray):\n        return value\n    elif hasattr(value, \"to\"): # Allow handling of custom batch classes\n        return fn(value)\n    elif isinstance(value, container_abcs.Mapping):\n        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}\n    elif isinstance(value, container_abcs.Iterable):\n        return type(value)(applier(v, fn) for v in value)\n    else:\n        # Do I want this to fire off even if someone chooses to pass something ordinary like\n        # an int or float?  May be more annoying than it's worth.\n        # print(\"Warning:  unrecognized type in applier.  If your input data is a custom class, \"\n        #     \"provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. \"\n        #     \"Amp will check for your custom to() and invoke it to cast the batch's \"\n        #     \"floating-point Tensors to the appropriate type. \"\n        #     \"Also, if your data is a custom class, it is your responsibility to ensure that \"\n        #     \"any Tensors you want to be cuda are already cuda.\"\n        return value\n\n\ndef check_models(models):\n    for model in models:\n        parallel_type = None\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel):\n            parallel_type = \"torch.nn.parallel.DistributedDataParallel\"\n        if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):\n            parallel_type = \"apex.parallel.DistributedDataParallel\"\n        if isinstance(model, torch.nn.parallel.DataParallel):\n            parallel_type = \"torch.nn.parallel.DataParallel\"\n        if parallel_type is not None:\n            raise RuntimeError(\"Incoming model is an instance of {}. 
\".format(parallel_type) +\n                \"Parallel wrappers should only be applied to the model(s) AFTER \\n\"\n                \"the model(s) have been returned from amp.initialize.\")\n\n\ndef check_params_fp32(models):\n    for model in models:\n        for name, param in model.named_parameters():\n            if param.is_floating_point():\n                if 'Half' in param.type():\n                    warn_or_err(\"Found param {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you do not need to call .half() on your model\\n\"\n                        \"before passing it, no matter what optimization level you choose.\".format(\n                        name, param.type()))\n                elif not param.is_cuda:\n                    warn_or_err(\"Found param {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you need to provide a model with parameters\\n\"\n                        \"located on a CUDA device before passing it no matter what optimization level\\n\"\n                        \"you chose. 
Use model.to('cuda') to use the default device.\".format(\n                        name, param.type()))\n\n        # Backward compatibility for PyTorch 0.4\n        if hasattr(model, 'named_buffers'):\n            buf_iter = model.named_buffers()\n        else:\n            buf_iter = model._buffers\n        for obj in buf_iter:\n            if type(obj)==tuple:\n                name, buf = obj\n            else:\n                name, buf = obj, buf_iter[obj]\n            if buf.is_floating_point():\n                if 'Half' in buf.type():\n                    warn_or_err(\"Found buffer {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you do not need to call .half() on your model\\n\"\n                        \"before passing it, no matter what optimization level you choose.\".format(\n                        name, buf.type()))\n                elif not buf.is_cuda:\n                    warn_or_err(\"Found buffer {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you need to provide a model with buffers\\n\"\n                        \"located on a CUDA device before passing it no matter what optimization level\\n\"\n                        \"you chose. Use model.to('cuda') to use the default device.\".format(\n                        name, buf.type()))\n\n\ndef check_optimizers(optimizers):\n    for optim in optimizers:\n        bad_optim_type = None\n        if isinstance(optim, FP16_Optimizer_general):\n            bad_optim_type = \"apex.fp16_utils.FP16_Optimizer\"\n        if isinstance(optim, FP16_Optimizer_for_fused):\n            bad_optim_type = \"apex.optimizers.FP16_Optimizer\"\n        if bad_optim_type is not None:\n            raise RuntimeError(\"An incoming optimizer is an instance of {}. 
\".format(bad_optim_type) +\n                               \"The optimizer(s) passed to amp.initialize() must be bare \\n\"\n                               \"instances of either ordinary Pytorch optimizers, or Apex fused \\n\"\n                               \"optimizers.\\n\")\n\n\nclass O2StateDictHook(object):\n    def __init__(self, fn):\n        self.fn = fn\n\n    def __call__(self, module, state_dict, prefix, local_metadata):\n        for key in state_dict:\n            param = state_dict[key]\n            if 'Half' in param.type():\n                param = param.to(torch.float32)\n                state_dict[key] = param\n\n\ndef _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):\n    from .amp import init as amp_init\n\n    optimizers_was_list = False\n    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):\n        optimizers = [optimizers]\n    elif optimizers is None:\n        optimizers = []\n    elif isinstance(optimizers, list):\n        optimizers_was_list = True\n        check_optimizers(optimizers)\n    else:\n        check_optimizers([optimizers])\n        raise TypeError(\"optimizers must be either a single optimizer or a list of optimizers.\")\n\n    if isinstance(models, torch.nn.Module):\n        models_was_list = False\n        models = [models]\n    elif isinstance(models, list):\n        models_was_list = True\n    else:\n        raise TypeError(\"models must be either a single model or a list of models.\")\n\n    check_models(models)\n\n    if not _amp_state.allow_incoming_model_not_fp32:\n        check_params_fp32(models)\n\n    # In the future, when FP16_Optimizer can be deprecated and master weights can\n    # become an attribute, remember to stash master weights before casting the model.\n\n    if properties.cast_model_type:\n        if properties.keep_batchnorm_fp32:\n            for model in models:\n                convert_network(model, 
properties.cast_model_type)\n        else:\n            for model in models:\n                model.to(properties.cast_model_type)\n\n        input_caster = functools.partial(to_type, properties.cast_model_type)\n        if cast_model_outputs is not None:\n            output_caster = functools.partial(to_type, cast_model_outputs)\n        else:\n            output_caster = functools.partial(to_type, torch.float32)\n\n        for model in models:\n            # Patch the forward method to cast incoming data to the correct type, and\n            # outgoing data to float32, so \"the user never needs to call .half().\"\n            # I like writing things explicitly more than decorators.\n            def patch_forward(old_fwd):\n                def new_fwd(*args, **kwargs):\n                    output = old_fwd(*applier(args, input_caster),\n                                     **applier(kwargs, input_caster))\n                    return applier(output, output_caster)\n                return new_fwd\n\n            model.forward = patch_forward(model.forward)\n\n        # State dict trick to recast any preexisting per-param state tensors\n        for optimizer in optimizers:\n            optimizer.load_state_dict(optimizer.state_dict())\n\n        # patch model.state_dict() to return float32 params\n        for model in models:\n            for module in model.modules():\n                module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))\n\n    elif cast_model_outputs is not None:\n        output_caster = functools.partial(to_type, cast_model_outputs)\n\n        for model in models:\n            def patch_forward(old_fwd):\n                def new_fwd(*args, **kwargs):\n                    output = old_fwd(*args, **kwargs)\n                    return applier(output, output_caster)\n                return new_fwd\n\n            model.forward = patch_forward(model.forward)\n\n    for i, optimizer in enumerate(optimizers):\n        
optimizers[i] = _process_optimizer(optimizer, properties)\n\n    _amp_state.loss_scalers = []\n    for _ in range(num_losses):\n        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,\n                                                  min_loss_scale=_amp_state.min_loss_scale,\n                                                  max_loss_scale=_amp_state.max_loss_scale))\n\n    if properties.patch_torch_functions:\n        # handle is unused here. It's accessible later through a global value anyway.\n        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))\n        for optimizer in optimizers:\n            # Disable Amp casting for the optimizer step, because it should only be\n            # applied to FP32 master params anyway.\n            def patch_step(old_step):\n                def new_step(self, *args, **kwargs):\n                    with disable_casts():\n                        output = old_step(*args, **kwargs)\n                    return output\n                return new_step\n\n            optimizer.step = MethodType(patch_step(optimizer.step), optimizer)\n\n    if optimizers_was_list:\n        if models_was_list:\n            return models, optimizers\n        else:\n            return models[0], optimizers\n    else:\n        if models_was_list:\n            if len(optimizers) == 0:\n                return models\n            else:\n                return models, optimizers[0]\n        else:\n            if len(optimizers) == 0:\n                return models[0]\n            else:\n                return models[0], optimizers[0]\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/_process_optimizer.py",
    "content": "import types\nfrom ..fp16_utils import master_params_to_model_params\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom ._amp_state import maybe_print\nimport torch\nfrom ..optimizers import FusedSGD\n\n\nclass AmpOptimizerState(object):\n    def __init__(self):\n        pass\n\n\ndef _master_params_to_model_params(self):\n    stash = self._amp_stash\n    if multi_tensor_applier.available:\n        if len(stash.all_fp16_params) > 0:\n            multi_tensor_applier(\n                stash.multi_tensor_scale,\n                stash.dummy_overflow_buf,\n                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],\n                1.0)\n    else:\n        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):\n            master_params_to_model_params(fp16_group, fp32_from_fp16_group)\n\n\ndef lazy_init_with_master_weights(self):\n        stash = self._amp_stash\n        stash.fp16_groups = []\n        stash.fp32_from_fp16_groups = []\n        stash.fp32_from_fp32_groups = []\n        for i, param_group in enumerate(self.param_groups):\n            # maybe_print(\"FP16_Optimizer processing param group {}:\".format(i))\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(param_group['params']):\n                if param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        # maybe_print(\"FP16_Optimizer received torch.cuda.HalfTensor with {}\"\n                        #             .format(param.size()))\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        param_group['params'][i] = master_param\n                        
fp32_from_fp16_params_this_group.append(master_param)\n                        # Reset existing state dict key to the new master param.\n                        # We still need to recast per-param state tensors, if any, to FP32.\n                        if param in self.state:\n                           self.state[master_param] = self.state.pop(param)\n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        # maybe_print(\"FP16_Optimizer received torch.cuda.FloatTensor with {}\"\n                        #             .format(param.size()))\n                        fp32_params_this_group.append(param)\n                        param_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Optimizer's parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. \"\n                                        \"Received {}\".format(param.type()))\n\n            stash.fp16_groups.append(fp16_params_this_group)\n            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            stash.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n        stash.all_fp16_params = []\n        for group in stash.fp16_groups:\n            stash.all_fp16_params += group\n\n        stash.all_fp32_from_fp16_params = []\n        for group in stash.fp32_from_fp16_groups:\n            stash.all_fp32_from_fp16_params += group\n\n        stash.all_fp32_from_fp32_params = []\n        for group in stash.fp32_from_fp32_groups:\n            stash.all_fp32_from_fp32_params += group\n\n        # all_fp16_grad_stash is only needed for fused optimizers.\n        stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]\n        # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]\n        stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]\n\n        for param in 
stash.all_fp32_from_fp16_params:\n            param.grad = None\n\n        for param in stash.all_fp32_from_fp32_params:\n            param.grad = None\n\n        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors\n        self.load_state_dict(self.state_dict())\n\n\ndef post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):\n        grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0\n\n        # not much to do if scale == 1.0 and static scaling\n        if scaler.loss_scale() == 1.0 and not scaler.dynamic:\n            # Clear the stash.\n            for i in range(len(stashed_grads)):\n                stashed_grads[i] = None\n            return\n        \n        if scale_override is not None:\n            grads_have_scale, stashed_have_scale, out_scale = scale_override\n\n        # This is a lot of python overhead...\n        grads_needing_unscale = []\n        grads_needing_unscale_with_stash = []\n        stashed = []\n        for param, stashed_grad in zip(params, stashed_grads):\n            if param.grad is None and stashed_grad is not None:\n                param.grad = stashed_grad\n            elif param.grad is not None and stashed_grad is None:\n                grads_needing_unscale.append(param.grad)\n            elif param.grad is not None and stashed_grad is not None:\n                grads_needing_unscale_with_stash.append(param.grad)\n                stashed.append(stashed_grad)\n            else: # param.grad is None and stashed_grad is None\n                continue\n\n        # unscale() implements grads*(1/scale), so \"scale\" should be grads_have_scale/out_scale.\n        if len(grads_needing_unscale) > 0:\n            scaler.unscale(\n                grads_needing_unscale,\n                grads_needing_unscale,\n                None, # unused_scale, currently present to avoid API breakage elsewhere\n                models_are_masters=True,\n 
               scale_override=grads_have_scale/out_scale)\n\n        if len(grads_needing_unscale_with_stash) > 0:\n            scaler.unscale_with_stashed(\n                grads_needing_unscale_with_stash,\n                stashed,\n                grads_needing_unscale_with_stash,\n                scale_override=(grads_have_scale, stashed_have_scale, out_scale))\n\n        # Clear the stash.\n        for i in range(len(stashed_grads)):\n            stashed_grads[i] = None\n\n\ndef prepare_backward_with_master_weights(self):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    for i, param in enumerate(stash.all_fp16_params):\n        # Set up to leverage grad copy elision.\n        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.\n        param.grad = None\n\n    # for i, param in enumerate(stash.all_fp32_from_fp16_params):\n    #     stash.all_fp32_from_fp16_grad_stash[i] = param.grad\n\n    for i, param in enumerate(stash.all_fp32_from_fp32_params):\n        stash.all_fp32_from_fp32_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n\ndef post_backward_with_master_weights(self, scaler):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    # This is a lot of python overhead...\n    fp16_grads_needing_unscale = []\n    new_fp32_grads = []\n    fp16_grads_needing_unscale_with_stash = []\n    preexisting_fp32_grads = []\n    for fp16_param, fp32_param in zip(stash.all_fp16_params,\n                                      stash.all_fp32_from_fp16_params):\n        if fp16_param.grad is None and fp32_param.grad is not None:\n            continue\n        elif fp16_param.grad is not None and fp32_param.grad is None:\n            fp32_param.grad = torch.empty_like(fp32_param)\n            fp16_grads_needing_unscale.append(fp16_param.grad)\n            new_fp32_grads.append(fp32_param.grad)\n        elif fp16_param.grad is not None and 
fp32_param.grad is not None:\n            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)\n            preexisting_fp32_grads.append(fp32_param.grad)\n        else: # fp16_param.grad is None and fp32_param.grad is None:\n            continue\n\n    if len(fp16_grads_needing_unscale) > 0:\n        scaler.unscale(\n            fp16_grads_needing_unscale,\n            new_fp32_grads,\n            scaler.loss_scale(),\n            models_are_masters=False)\n\n    if len(fp16_grads_needing_unscale_with_stash) > 0:\n        scaler.unscale_with_stashed(\n            fp16_grads_needing_unscale_with_stash,\n            preexisting_fp32_grads,\n            preexisting_fp32_grads)\n\n    # fp32 params can be treated as they would be in the \"no_master_weights\" case.\n    post_backward_models_are_masters(\n        scaler,\n        stash.all_fp32_from_fp32_params,\n        stash.all_fp32_from_fp32_grad_stash)\n\n\ndef lazy_init_no_master_weights(self):\n    stash = self._amp_stash\n    stash.all_fp16_params = []\n    stash.all_fp32_params = []\n    for i, param_group in enumerate(self.param_groups):\n        for i, param in enumerate(param_group['params']):\n            if param.type() == 'torch.cuda.HalfTensor':\n                stash.all_fp16_params.append(param)\n            elif param.type() == 'torch.cuda.FloatTensor':\n                stash.all_fp32_params.append(param)\n            else:\n                raise TypeError(\"Optimizer's parameters must be either \"\n                                \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"\n                                \"Received {}\".format(param.type()))\n\n    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]\n    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]\n\n\ndef prepare_backward_no_master_weights(self):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    for i, param in enumerate(stash.all_fp16_params):\n        stash.all_fp16_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n    for i, param in enumerate(stash.all_fp32_params):\n        stash.all_fp32_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n\ndef post_backward_no_master_weights(self, scaler):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),\n             (stash.all_fp32_params, stash.all_fp32_grad_stash))\n\n    for params, stashed_grads in split_types:\n        post_backward_models_are_masters(scaler, params, stashed_grads)\n\n\n#####################################################################################\n# FusedSGD versions\n#####################################################################################\n\n# FusedSGD never explicitly materializes the fp32 gradients for \"fp32 from fp16\" master params\n# outside the kernel, so we must accumulate directly into the model grads.\ndef prepare_backward_with_master_weights_FusedSGD(self):\n    if self.materialize_master_grads:\n        prepare_backward_with_master_weights(self)\n    else:\n        stash = self._amp_stash\n\n        self._amp_lazy_init()\n\n        for i, param in enumerate(stash.all_fp16_params):\n            stash.all_fp16_grad_stash[i] = param.grad\n            # Set up to leverage grad copy elision:\n            param.grad = None\n\n        for i, param in enumerate(stash.all_fp32_from_fp32_params):\n            stash.all_fp32_from_fp32_grad_stash[i] = 
param.grad\n            # Set up to leverage grad copy elision:\n            param.grad = None\n\n\ndef post_backward_with_master_weights_FusedSGD(self, scaler):\n    if self.materialize_master_grads:\n        post_backward_with_master_weights(self, scaler)\n    else:\n        stash = self._amp_stash\n\n        self._amp_lazy_init()\n\n        grads_have_scale = scaler.loss_scale()\n        stashed_have_scale = self.most_recent_scale\n        out_scale = grads_have_scale\n        if self.scale_set_by_backward:\n            out_scale = min(grads_have_scale, self.most_recent_scale)\n\n        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),\n                 (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))\n\n\n        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.\n        # stashed_grads are scaled by self.most_recent_scale.\n        for params, stashed_grads in split_types:\n            post_backward_models_are_masters(scaler, params, stashed_grads,\n                                             (grads_have_scale, stashed_have_scale, out_scale))\n\n        self.most_recent_scale = out_scale\n        self.scale_set_by_backward = True\n\n\ndef prepare_backward_no_master_weights_FusedSGD(self):\n    prepare_backward_no_master_weights(self)\n\n\ndef post_backward_no_master_weights_FusedSGD(self, scaler):\n    post_backward_no_master_weights(self, scaler)\n\n\ndef _amp_lazy_init(self):\n    stash = self._amp_stash\n\n    if not stash.lazy_init_called:\n        self._lazy_init_maybe_master_weights()\n        stash.lazy_init_called = True\n\n\ndef _process_optimizer(optimizer, properties):\n    if hasattr(optimizer, \"_amp_stash\"):\n        raise RuntimeError(\"A given optimizer should only be passed through amp.initialize once.\")\n    else:\n        optimizer._amp_stash = AmpOptimizerState()\n\n    optimizer._amp_stash.lazy_init_called = False\n    optimizer._amp_stash.already_patched = False\n    
optimizer._amp_stash.params_have_scaled_gradients = False\n\n    for name in (\"_lazy_init_maybe_master_weights\",\n                 \"_master_params_to_model_params\",\n                 \"_prepare_amp_backward\",\n                 \"_post_amp_backward\",\n                 \"_amp_lazy_init\"):\n        if hasattr(optimizer, name):\n            raise RuntimeError(\"Incoming optimizer already has {} defined.\".format(name))\n\n    # TODO:  Centralize exposure and import error checking for the C backend.\n    if multi_tensor_applier.available:\n        import amp_C\n        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale\n        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);\n\n    if properties.master_weights:\n        optimizer._lazy_init_maybe_master_weights = types.MethodType(\n            lazy_init_with_master_weights, optimizer)\n\n        optimizer._master_params_to_model_params = types.MethodType(\n            _master_params_to_model_params, optimizer)\n\n        old_step = optimizer.step\n        def new_step(self, closure=None):\n            if closure is not None:\n                raise RuntimeError(\"Currently, Amp does not support closure use with optimizers.\")\n            retval = old_step()\n            if not isinstance(self, FusedSGD):\n                self._master_params_to_model_params()\n            # Clear the master grads that wouldn't be zeroed by model.zero_grad()\n            for param in self._amp_stash.all_fp32_from_fp16_params:\n                param.grad = None\n            return retval\n        optimizer.step = types.MethodType(new_step, optimizer)\n\n        old_zero_grad = optimizer.zero_grad\n        def new_zero_grad(self):\n            stash = self._amp_stash\n            self._amp_lazy_init()\n            # Zero the model grads.\n            for param in stash.all_fp16_params:\n                if param.grad is 
not None:\n                    param.grad.detach_()\n                    param.grad.zero_()\n            for param in stash.all_fp32_from_fp32_params:\n                if param.grad is not None:\n                    param.grad.detach_()\n                    param.grad.zero_()\n            # Clear the master grads that are independent of model grads\n            for param in self._amp_stash.all_fp32_from_fp16_params:\n                param.grad = None\n        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)\n\n        if isinstance(optimizer, FusedSGD):\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_with_master_weights_FusedSGD, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_with_master_weights_FusedSGD, optimizer)\n        else:\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_with_master_weights, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_with_master_weights, optimizer)\n    else:\n        optimizer._lazy_init_maybe_master_weights = types.MethodType(\n            lazy_init_no_master_weights, optimizer)\n\n        if isinstance(optimizer, FusedSGD):\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_no_master_weights_FusedSGD, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_no_master_weights_FusedSGD, optimizer)\n        else:\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_no_master_weights, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_no_master_weights, optimizer)\n\n    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)\n\n    old_add_param_group = optimizer.add_param_group\n\n    def 
new_add_param_group(self, new_group):\n        stash = self._amp_stash\n\n        if not stash.lazy_init_called:\n            self._lazy_init_maybe_master_weights()\n            stash.lazy_init_called = True\n\n        assert isinstance(new_group, dict), \"param group must be a dict\"\n\n        new_params = new_group['params']\n        if isinstance(new_params, torch.Tensor):\n            new_group['params'] = [new_params]\n        elif isinstance(new_params, set):\n            raise TypeError('optimizer parameters need to be organized in ordered collections, but '\n                            'the ordering of tensors in sets will change between runs. Please use a list instead.')\n        else:\n            new_group['params'] = list(new_params)\n\n        if properties.master_weights:\n            # Mutate new_group in-place to use FP32 master params\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(new_group['params']):\n                if param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        new_group['params'][i] = master_param\n                        fp32_from_fp16_params_this_group.append(master_param)\n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        fp32_params_this_group.append(param)\n                        new_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Optimizer's parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"\n                                        \"Received {}\".format(param.type()))\n\n            stash.fp16_groups.append(fp16_params_this_group)\n            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            stash.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n            stash.all_fp16_params += fp16_params_this_group\n            stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group\n            stash.all_fp32_from_fp32_params += fp32_params_this_group\n\n            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]\n            stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]\n\n            # It should be ok to let params be added with existing .grad attributes.\n            # for param in fp16_params_this_group:\n            #     param.grad = None\n\n            # for param in fp32_from_fp16_params_this_group:\n            #     param.grad = None\n\n            # for param in stash.fp32_params_this_group:\n            #     param.grad = None\n        else:\n            for param in new_group['params']:\n                if param.type() == 'torch.cuda.HalfTensor':\n                    stash.all_fp16_params.append(param)\n                    stash.all_fp16_grad_stash.append(None)\n                elif param.type() == 'torch.cuda.FloatTensor':\n                    stash.all_fp32_params.append(param)\n                    stash.all_fp32_grad_stash.append(None)\n                else:\n                    raise TypeError(\"Optimizer's parameters must be either \"\n                                    \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. \"\n                                    \"Received {}\".format(param.type()))\n\n        old_add_param_group(new_group)\n\n    optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)\n\n    return optimizer\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/amp.py",
    "content": "from . import compat, rnn_compat, utils, wrap\nfrom .handle import AmpHandle, NoOpHandle\nfrom .lists import functional_overrides, torch_overrides, tensor_overrides\nfrom ._amp_state import _amp_state\nfrom .frontend import *\n\nimport functools\nimport itertools\n\nimport torch\n\n\n_DECORATOR_HANDLE = None\n_USER_CAST_REGISTRY = set()\n_USER_PROMOTE_REGISTRY = set()\n\n\ndef _decorator_helper(orig_fn, cast_fn, wrap_fn):\n    def wrapper(*args, **kwargs):\n        handle = _DECORATOR_HANDLE\n        if handle is None or not handle.is_active():\n            return orig_fn(*args, **kwargs)\n        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,\n                                  handle.verbose)\n        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)\n    return wrapper\n\n\n# Decorator form\ndef half_function(fn):\n    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)\n    return _decorator_helper(fn, utils.maybe_half, wrap_fn)\n\n\ndef float_function(fn):\n    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)\n    return _decorator_helper(fn, utils.maybe_float, wrap_fn)\n\n\ndef promote_function(fn):\n    wrap_fn = functools.partial(wrap.make_promote_wrapper)\n    return _decorator_helper(fn, utils.maybe_float, wrap_fn)\n\n\n# Registry form\ndef register_half_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n            name, module))\n    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))\n\n\ndef register_float_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n            name, module))\n    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))\n\n\ndef register_promote_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n          
  name, module))\n    _USER_PROMOTE_REGISTRY.add((module, name))\n\n\n# Top-level function to insert _all_ the hooks.\ndef init(enabled=True, loss_scale=\"dynamic\", enable_caching=True, verbose=False, allow_banned=False):\n    global _DECORATOR_HANDLE\n\n    if not enabled:\n        handle = NoOpHandle()\n        _DECORATOR_HANDLE = handle\n        return handle\n\n    handle = AmpHandle(loss_scale, enable_caching, verbose)\n\n    # 0) Force-{fp16, fp32} for user-annotated functions\n    for mod, fn, cast_fn in _USER_CAST_REGISTRY:\n        try_caching = (cast_fn == utils.maybe_half)\n        wrap.cached_cast(mod, fn, cast_fn, handle,\n                         try_caching, verbose)\n    _USER_CAST_REGISTRY.clear()\n\n    # 0.5) Force-promote for user-annotated functions\n    for mod, fn in _USER_PROMOTE_REGISTRY:\n        wrap.promote(mod, fn, handle, verbose)\n    _USER_PROMOTE_REGISTRY.clear()\n\n    # 1) Force-{fp16, fp32} on white- / black-list functions\n    override_modules = [functional_overrides,\n                        torch_overrides,\n                        tensor_overrides]\n    cast_table = [('FP16_FUNCS', utils.maybe_half),\n                  ('FP32_FUNCS', utils.maybe_float)]\n    for module, (list_name, cast_fn) in itertools.product(override_modules,\n                                                          cast_table):\n        for fn in getattr(module, list_name):\n            try_caching = (cast_fn == utils.maybe_half)\n            wrap.cached_cast(module.MODULE, fn, cast_fn, handle,\n                             try_caching, verbose)\n\n    # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist\n    #      methods on FloatTensor, since they're distinct types.\n    if compat.tensor_is_float_tensor():\n        for fn in tensor_overrides.FP16_FUNCS:\n            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,\n                             handle, try_caching=True, verbose=verbose)\n        for fn in 
tensor_overrides.FP32_FUNCS:\n            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,\n                             handle, try_caching=False, verbose=verbose)\n\n    # 2) Enable type-promotion on multi-arg functions and methods.\n    #    NB: special handling for sequence fns (e.g. `torch.cat`).\n    promote_modules = [torch_overrides, tensor_overrides]\n    promote_table = [('CASTS', wrap.promote),\n                     ('SEQUENCE_CASTS', wrap.sequence_promote)]\n    for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,\n                                                                  promote_table):\n        for fn in getattr(promote_mod, list_name):\n            promote_fn(promote_mod.MODULE, fn, handle, verbose)\n\n    # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types\n    if compat.tensor_is_float_tensor():\n        for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,\n                                                               torch.cuda.HalfTensor],\n                                                              promote_table):\n            for fn in getattr(tensor_overrides, list_name):\n                promote_fn(cls, fn, handle, verbose)\n\n    # 3) For any in-place version of a blacklist function, error if any input is fp16.\n    #    NB: this is overly conservative.\n    for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):\n        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)\n\n    # 3.5) For any in-place blacklist method, error if called on fp16 tensor\n    for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):\n        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)\n        if compat.tensor_is_float_tensor():\n            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)\n\n    # 4) For other in-place methods, match the type of self tensor\n    for fn in utils.as_inplace(itertools.chain(\n    
        tensor_overrides.FP16_FUNCS,\n            tensor_overrides.CASTS)):\n        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)\n        if compat.tensor_is_float_tensor():\n            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)\n            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)\n\n    # 5) RNNs + RNN cells are whitelisted specially\n    if rnn_compat.has_old_rnns():\n        wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)\n    if not rnn_compat.has_old_rnns():\n        # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.\n        torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()\n        # Wrap all the rnns\n        for x in rnn_compat.RNN_NAMES:\n            wrap.new_rnn_cast(x.upper(), handle, verbose)\n\n    # Wrap all the RNN cells\n    rnn_compat.whitelist_rnn_cells(handle, verbose)\n\n    # 6) Place error+print message on banned functions.\n    #    Or, if allow_banned, then cast to FP32.\n    for fn, err_msg in functional_overrides.BANNED_FUNCS:\n        if allow_banned:\n            wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,\n                             handle, try_caching=True, verbose=verbose)\n        else:\n            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)\n\n    _DECORATOR_HANDLE = handle\n\n    _amp_state.handle = handle\n\n    return handle\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/compat.py",
    "content": "import torch\n\n# True for post-0.4, when Variables/Tensors merged.\ndef variable_is_tensor():\n    v = torch.autograd.Variable()\n    return isinstance(v, torch.Tensor)\n\ndef tensor_is_variable():\n    x = torch.Tensor()\n    return type(x) == torch.autograd.Variable\n\n# False for post-0.4\ndef tensor_is_float_tensor():\n    x = torch.Tensor()\n    return type(x) == torch.FloatTensor\n\n# Akin to `torch.is_tensor`, but returns True for Variable\n# objects in pre-0.4.\ndef is_tensor_like(x):\n    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)\n\n# Wraps `torch.is_floating_point` if present, otherwise checks\n# the suffix of `x.type()`.\ndef is_floating_point(x):\n    if hasattr(torch, 'is_floating_point'):\n        return torch.is_floating_point(x)\n    try:\n        torch_type = x.type()\n        return torch_type.endswith('FloatTensor') or \\\n            torch_type.endswith('HalfTensor') or \\\n            torch_type.endswith('DoubleTensor')\n    except AttributeError:\n        return False\n\ndef scalar_python_val(x):\n    if hasattr(x, 'item'):\n        return x.item()\n    else:\n        if isinstance(x, torch.autograd.Variable):\n            return x.data[0]\n        else:\n            return x[0]\n\n# Accounts for the possibility that some ops may be removed from a namespace.\ndef filter_attrs(module, attrs):\n    return list(attrname for attrname in attrs if hasattr(module, attrname))\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/frontend.py",
    "content": "import torch\nfrom ._initialize import _initialize\nfrom ._amp_state import _amp_state, warn_or_err, maybe_print\nfrom collections import OrderedDict\n\n\nclass Properties(object):\n    \"\"\"\n    This class has two purposes: to establish a set of default properties,\n    and to route setting of these attributes through __setattr__ so that (in theory)\n    they can be checked for consistency with other existing args.\n    \"\"\"\n    def __init__(self):\n        self.options = {\n            \"enabled\" : False,\n            \"opt_level\" : None,\n            \"cast_model_type\" : None,\n            \"patch_torch_functions\" : False,\n            \"keep_batchnorm_fp32\" : None,\n            \"master_weights\" : None,\n            \"loss_scale\" : 1.0,\n            # Reserved for future functionality\n            # \"fused_optimizer\" : False,\n            # \"enable_ddp_interop\" : False,\n            }\n\n    \"\"\"\n    This function allows updating several options at a time without routing through\n    __setattr__ checks, to avoid \"you can't get there from here\" scenarios.\n    Currently not intended to be exposed; users are expected to select an opt_level\n    and apply consistent modifications.\n    \"\"\"\n    def _update_options_dict(self, new_options):\n        for k, v in new_options:\n            if k in self.options:\n                self.options[k] = v\n            else:\n                raise ValueError(\"Tried to set unexpected option {}\".format(k))\n    \"\"\"\n    The members of \"options\" are not direct attributes of self, so access attempts\n    will roll down to __getattr__.  
This borrows from the logic in torch.nn.Module.\n    \"\"\"\n    def __getattr__(self, name):\n        if \"options\" in self.__dict__:\n            options =  self.__dict__[\"options\"]\n            if name in options:\n                return options[name]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, name))\n\n    def __setattr__(self, name, value):\n        if \"options\" in self.__dict__:\n            if name in self.options:\n                # print(\"setting {} {}\".format(name, value))\n                if name == \"cast_model_type\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        if value is not False:\n                            if value is not torch.float32:\n                                warn_or_err(\"O1 inserts casts around Torch functions rather than \"\n                                            \"model weights, so with O1, the model weights themselves \"\n                                            \"should remain FP32. If you wish to cast the model to a \"\n                                            \"different type, use opt_level='O2' or 'O3'. 
\" +\n                                            \"cast_model_type was {}\".format(value))\n                    self.options[name] = value\n                elif name == \"patch_torch_functions\":\n                    if self.opt_level != \"O1\" and value:\n                        warn_or_err(\"Currently, patch_torch_functions=True should only be set by \"\n                                    \"selecting opt_level='O1'.\")\n                    self.options[name] = value\n                elif name == \"keep_batchnorm_fp32\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        warn_or_err(\"With opt_level O1, batchnorm functions are automatically patched \"\n                                    \"to run in FP32, so keep_batchnorm_fp32 should be None.\" +\n                                    \" keep_batchnorm_fp32 was {}\".format(value))\n                    if value == \"False\":\n                        self.options[name] = False\n                    elif value == \"True\":\n                        self.options[name] = True\n                    else:\n                        assert (value is True or value is False or value is None),\\\n                            \"keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', \"\\\n                            \"or None, found keep_batchnorm_fp32={}\".format(value)\n                        self.options[name] = value\n                elif name == \"master_weights\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        warn_or_err(\"It doesn't make sense to use master_weights with O1. 
\"\n                                    \"With O1, your model weights themselves should be FP32.\")\n                    self.options[name] = value\n                elif name == \"loss_scale\":\n                    if value == \"dynamic\":\n                        self.options[name] = value\n                    else:\n                        self.options[name] = float(value)\n                else:\n                    self.options[name] = value\n        else:\n            super(Properties, self).__setattr__(name, value)\n\n\n\"\"\" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. \"\"\"\n\nclass O3:\n    brief = \"O3:  Pure FP16 training.\"\n    more = \"Calls .half() on your model, converting the entire model to FP16.\\n\"\\\n        \"A casting operation is also inserted to cast incoming Tensors to FP16,\\n\"\\\n        \"so you don't need to change your data pipeline.\\n\"\\\n        \"This mode is useful for establishing a performance ceiling.\\n\"\\\n        \"It's also possible training may 'just work' in this mode.\\n\"\\\n        \"If not, try other optimization levels.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O3\"\n        properties.cast_model_type = torch.float16\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = False\n        properties.master_weights = False\n        properties.loss_scale = 1.0\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O2:\n    brief = \"O2:  FP16 training with FP32 batchnorm and FP32 master weights.\\n\"\n    more = \"Calls .half() on your model, converting the entire model (except for batchnorms)\\n\"\\\n        \"to FP16.  
Batchnorms are retained in FP32 for additional stability.\\n\"\\\n        \"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\\n\"\\\n        \"your data pipeline.\\n\"\\\n        \"O2 creates FP32 master weights outside the model and patches any optimizers to update\\n\"\\\n        \"these master weights, then copy the master weights into the FP16 model weights.\\n\"\\\n        \"Master weights can also improve convergence and stability.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O2\"\n        properties.cast_model_type = torch.float16\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = True\n        properties.master_weights = True\n        properties.loss_scale = \"dynamic\"\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O1:\n    brief = \"O1:  Insert automatic casts around Pytorch functions and Tensor methods.\\n\"\n    more = \"The type of your model's weights is not altered.  
However, internally,\\n\"\\\n        \"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\\n\"\\\n        \"while operations that might benefit from the additional stability of FP32 are patched\\n\"\\\n        \"to cast their inputs to fp32.\\n\"\\\n        \"O1 is the safest way to try mixed precision training, and is recommended when\\n\"\\\n        \"trying mixed precision training for the first time.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O1\"\n        properties.cast_model_type = None\n        properties.patch_torch_functions = True\n        properties.keep_batchnorm_fp32 = None\n        properties.master_weights = None\n        properties.loss_scale = \"dynamic\"\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O0:\n    brief = \"O0:  Pure FP32 training.\\n\"\n    more = \"Your models are checked to make sure parameters are FP32, but otherwise the\\n\"\\\n        \"types of weights and internal Pytorch operations are not altered.  
This mode disables any\\n\"\\\n        \"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\\n\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O0\"\n        properties.cast_model_type = torch.float32\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = None\n        properties.master_weights = False\n        properties.loss_scale = 1.0\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nopt_levels = {\"O3\": O3(),\n              \"O2\": O2(),\n              \"O1\": O1(),\n              \"O0\": O0()}\n\n\n# allow user to directly pass Properties struct as well?\ndef initialize(\n    models,\n    optimizers=None,\n    enabled=True,\n    opt_level=\"O1\",\n    cast_model_type=None,\n    patch_torch_functions=None,\n    keep_batchnorm_fp32=None,\n    master_weights=None,\n    loss_scale=None,\n    cast_model_outputs=None,\n    num_losses=1,\n    verbosity=1,\n    min_loss_scale=None,\n    max_loss_scale=2.**24\n    ):\n    \"\"\"\n    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the\n    chosen ``opt_level`` and overridden properties, if any.\n\n    ``amp.initialize`` should be called **after** you have finished\n    constructing your model(s) and\n    optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.\n    See `Distributed training`_ in the Imagenet example.\n\n    Currently, ``amp.initialize`` should only be called **once**,\n    although it can process an arbitrary number of\n    models and optimizers (see the corresponding `Advanced Amp Usage topic`_).\n    If you think your use case requires ``amp.initialize`` to be called more than once,\n    `let us know`_.\n\n    Any property keyword argument that is not 
``None`` will be interpreted as a manual override.\n\n    To prevent having to rewrite anything else in your script, name the returned models/optimizers\n    to replace the passed models/optimizers, as in the code sample below.\n\n    Args:\n        models (torch.nn.Module or list of torch.nn.Modules):  Models to modify/cast.\n        optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers):  Optimizers to modify/cast.\n            REQUIRED for training, optional for inference.\n        enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script\n            should run as if Amp were not present.\n        opt_level (str, optional, default=\"O1\"):  Pure or mixed precision optimization level.  Accepted values are\n            \"O0\", \"O1\", \"O2\", and \"O3\", explained in detail above.\n        cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see\n            above.\n        patch_torch_functions (bool, optional, default=None):  Optional property override.\n        keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If\n            passed as a string, must be the string \"True\" or \"False\".\n        master_weights (bool, optional, default=None):  Optional property override.\n        loss_scale (float or str, optional, default=None):  Optional property override.  If passed as a string,\n            must be a string representing a number, e.g., \"128.0\", or the string \"dynamic\".\n        cast_model_outputs (torch.dtype, optional, default=None):  Option to ensure that the outputs\n            of your model(s) are always cast to a particular type regardless of ``opt_level``.\n        num_losses (int, optional, default=1):  Option to tell Amp in advance how many losses/backward\n            passes you plan to use.  
When used in conjunction with the ``loss_id`` argument to\n            ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,\n            which can improve stability.  See \"Multiple models/optimizers/losses\"\n            under `Advanced Amp Usage`_ for examples.  If ``num_losses`` is left to 1, Amp will still\n            support multiple losses/backward passes, but use a single global loss scale\n            for all of them.\n        verbosity (int, default=1):  Set to 0 to suppress Amp-related output.\n        min_loss_scale (float, default=None):  Sets a floor for the loss scale values that can be chosen by dynamic\n            loss scaling.  The default value of None means that no floor is imposed.\n            If dynamic loss scaling is not used, `min_loss_scale` is ignored.\n        max_loss_scale (float, default=2.**24):  Sets a ceiling for the loss scale values that can be chosen by\n            dynamic loss scaling.  If dynamic loss scaling is not used, `max_loss_scale` is ignored.\n\n    Returns:\n        Model(s) and optimizer(s) modified according to the ``opt_level``.\n        If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will\n        also be a list.\n\n    Permissible invocations::\n\n        model, optim = amp.initialize(model, optim,...)\n        model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)\n        [model1, model2], optim = amp.initialize([model1, model2], optim,...)\n        [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)\n\n        # This is not an exhaustive list of the cross product of options that are possible,\n        # just a set of examples.\n        model, optim = amp.initialize(model, optim, opt_level=\"O0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O0\", loss_scale=\"dynamic\"|128.0|\"128.0\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O1\") # 
uses loss_scale=\"dynamic\" default\n        model, optim = amp.initialize(model, optim, opt_level=\"O1\", loss_scale=128.0|\"128.0\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\") # uses loss_scale=\"dynamic\" default\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\", loss_scale=128.0|\"128.0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\", keep_batchnorm_fp32=True|False|\"True\"|\"False\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\") # uses loss_scale=1.0 default\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\", loss_scale=\"dynamic\"|128.0|\"128.0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\", keep_batchnorm_fp32=True|False|\"True\"|\"False\")\n\n    The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.\n\n    .. _`Distributed training`:\n        https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training\n\n    .. _`Imagenet example`:\n        https://github.com/NVIDIA/apex/tree/master/examples/imagenet\n\n    .. _`Advanced Amp Usage`:\n        https://nvidia.github.io/apex/advanced.html\n\n    .. _`Advanced Amp Usage topic`:\n        https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses\n\n    .. _`let us know`:\n        https://github.com/NVIDIA/apex/issues\n    \"\"\"\n    _amp_state.opt_properties = Properties()\n    _amp_state.verbosity = verbosity\n\n    if not enabled:\n        if optimizers is None:\n            return models\n        else:\n            return models, optimizers\n\n    if not torch.backends.cudnn.enabled:\n        raise RuntimeError(\n            \"Amp requires torch.backends.cudnn.enabled = True\")\n\n    if opt_level not in opt_levels:\n        raise RuntimeError(\n            \"Unexpected optimization level {}. \".format(opt_level) +\n            \"Options are 'O0', 'O1', 'O2', 'O3'.  
Note that in `O0`, `O1`, etc., the prefix O is the letter O, \" +\n            \"not the number zero.\")\n    else:\n        _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)\n        maybe_print(\"Selected optimization level {}\".format(opt_levels[opt_level].brief), True)\n        maybe_print(\"Defaults for this optimization level are:\", True)\n        for k, v in _amp_state.opt_properties.options.items():\n            maybe_print(\"{:22} : {}\".format(k, v), True)\n\n    _amp_state.min_loss_scale = min_loss_scale\n    _amp_state.max_loss_scale = max_loss_scale\n\n    maybe_print(\"Processing user overrides (additional kwargs that are not None)...\", True)\n    # I chose to have the keyword arguments listed directly in the argument list,\n    # instead of **kwargs, so I can't use kwargs.items() here.\n    if enabled is not None:\n        _amp_state.opt_properties.enabled = enabled\n    if opt_level is not None:\n        _amp_state.opt_properties.opt_level = opt_level\n    if cast_model_type is not None:\n        _amp_state.opt_properties.cast_model_type = cast_model_type\n    if patch_torch_functions is not None:\n        _amp_state.opt_properties.patch_torch_functions = patch_torch_functions\n    if keep_batchnorm_fp32 is not None:\n        _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32\n    if master_weights is not None:\n        _amp_state.opt_properties.master_weights = master_weights\n    if loss_scale is not None:\n        _amp_state.opt_properties.loss_scale = loss_scale\n\n    maybe_print(\"After processing overrides, optimization options are:\", True)\n    for k, v in _amp_state.opt_properties.options.items():\n        maybe_print(\"{:22} : {}\".format(k, v), True)\n\n    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)\n\n\ndef state_dict(destination=None):\n    if destination is None:\n        destination = OrderedDict()\n\n    for idx, loss_scaler in 
enumerate(_amp_state.loss_scalers):\n        destination['loss_scaler%d' % idx] = {\n            'loss_scale': loss_scaler.loss_scale(),\n            'unskipped': loss_scaler._unskipped,\n        }\n    return destination\n\n\ndef load_state_dict(state_dict):\n    # Check if state_dict contains the same number of loss_scalers as current setup\n    if len(state_dict) != len(_amp_state.loss_scalers):\n        print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(\n            len(state_dict), len(_amp_state.loss_scalers)))\n\n    state_dict = state_dict.copy()\n\n    nb_loss_scalers = len(_amp_state.loss_scalers)\n    unexpected_keys = []\n    # Initialize idx manually, since enumerate would also count unexpected keys\n    idx = 0\n    for key in state_dict:\n        if 'loss_scaler' not in key:\n            unexpected_keys.append(key)\n        else:\n            if idx > (nb_loss_scalers - 1):\n                print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(\n                    idx, nb_loss_scalers))\n                break\n            _amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']\n            _amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']\n            idx += 1\n\n    if len(unexpected_keys) > 0:\n        raise RuntimeError(\n            'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. 
'.format(\n                ', '.join('\"{}\"'.format(k) for k in unexpected_keys)))\n\n\n# TODO:  is this necessary/useful?\n# def check_option_consistency(enabled=True,\n#                              opt_level=None,\n#                              cast_model_type=None,\n#                              patch_torch_functions=None,\n#                              keep_batchnorm_fp32=None,\n#                              master_weights=None,\n#                              loss_scale=None,\n#                              enable_ddp_interop=None,\n#                              hard_override=False):\n#     \"\"\"\n#     Utility function that enables users to quickly check if the option combination they intend\n#     to use is permitted.  ``check_option_consistency`` does not require models or optimizers\n#     to be constructed, and can be called at any point in the script.  ``check_option_consistency``\n#     is totally self-contained; it does not set any amp global state or affect anything outside\n#     of itself.\n#     \"\"\"\n#\n#     if not enabled:\n#         return\n#\n#     if opt_level not in opt_levels:\n#         raise RuntimeError(\"Unexpected optimization level.  
Options are 'O0', 'O1', 'O2', 'O3'.\")\n#     else:\n#         opt_properties = opt_levels[opt_level](Properties())\n#         print(\"Selected optimization level {}\".format(opt_levels[opt_level].brief))\n#         print(\"Defaults for this optimization level are:\")\n#         for k, v in opt_properties.options.items():\n#             print(\"{:22} : {}\".format(k, v))\n#\n#     print(\"Processing user overrides (additional kwargs that are not None)...\")\n#     for k, v in kwargs.items():\n#         if k not in opt_properties.options:\n#             raise RuntimeError(\"Unexpected kwarg {}\".format(k))\n#         if v is not None:\n#             setattr(opt_properties, k, v)\n#\n#     print(\"After processing overrides, optimization options are:\")\n#     for k, v in opt_properties.options.items():\n#         print(\"{:22} : {}\".format(k, v))\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/handle.py",
    "content": "import contextlib\nimport warnings\nimport sys\nimport torch\n\nfrom . import utils\nfrom .opt import OptimWrapper\nfrom .scaler import LossScaler\nfrom ._amp_state import _amp_state, master_params, maybe_print\n\nif torch.distributed.is_available():\n    from ..parallel.LARC import LARC\n\n\n# There's no reason to expose the notion of a \"handle\". Everything can happen through amp.* calls.\n@contextlib.contextmanager\ndef scale_loss(loss,\n               optimizers,\n               loss_id=0,\n               model=None,\n               delay_unscale=False,\n               delay_overflow_check=False):\n    \"\"\"\n    On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.\n    ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::\n\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n\n    On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs\n    and unscaled, so that ``optimizer.step()`` can be called.\n\n    .. note::\n        If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and\n        can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)\n        any FP16 gradients are copied to FP32 master gradients before being unscaled.\n        ``optimizer.step()`` will then apply the unscaled master gradients to the master params.\n\n    .. warning::\n        If Amp is using explicit FP32 master params, only the FP32 master gradients will be\n        unscaled.  The direct ``.grad`` attributes of any FP16\n        model params will remain scaled after context manager exit.\n        This subtlety affects gradient clipping.  See \"Gradient clipping\" under\n        `Advanced Amp Usage`_ for best practices.\n\n    Args:\n        loss(Tensor):  Typically a scalar Tensor. 
The ``scaled_loss`` that the context\n            manager yields is simply ``loss.float()*loss_scale``, so in principle\n            ``loss`` could have more than one element, as long as you call\n            ``backward()`` on ``scaled_loss`` appropriately within the context manager body.\n        optimizers:  All optimizer(s) for which the current backward pass is creating gradients.\n            Must be an optimizer or list of optimizers returned from an earlier call\n            to ``amp.initialize``.  For example use with multiple optimizers, see\n            \"Multiple models/optimizers/losses\" under `Advanced Amp Usage`_.\n        loss_id(int, optional, default=0):  When used in conjunction with the ``num_losses`` argument\n            to ``amp.initialize``, enables Amp to use a different loss scale per loss.  ``loss_id``\n            must be an integer between 0 and ``num_losses`` that tells Amp which loss is\n            being used for the current backward pass.  See \"Multiple models/optimizers/losses\"\n            under `Advanced Amp Usage`_ for examples.  
If ``loss_id`` is left unspecified, Amp\n            will use the default global loss scaler for this backward pass.\n        model(torch.nn.Module, optional, default=None):  Currently unused, reserved to enable future\n            optimizations.\n        delay_unscale(bool, optional, default=False):  ``delay_unscale`` is never necessary, and\n            the default value of ``False`` is strongly recommended.\n            If ``True``, Amp will not unscale the gradients or perform model->master\n            gradient copies on context manager exit.\n            ``delay_unscale=True`` is a minor ninja performance optimization and can result\n            in weird gotchas (especially with multiple models/optimizers/losses),\n            so only use it if you know what you're doing.\n            \"Gradient accumulation across iterations\" under `Advanced Amp Usage`_\n            illustrates a situation where this CAN (but does not need to) be used.\n\n    .. warning::\n        If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be\n        called yet after context manager exit, and must wait for another, later backward context\n        manager invocation with ``delay_unscale`` left to False.\n\n    .. _`Advanced Amp Usage`:\n        https://nvidia.github.io/apex/advanced.html\n    \"\"\"\n    if not hasattr(_amp_state, \"opt_properties\"):\n        raise RuntimeError(\"Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized.  \"\n                           \"model, optimizer = amp.initialize(model, optimizer, opt_level=...) 
must be called \"\n                           \"before `with amp.scale_loss`.\")\n\n    if not _amp_state.opt_properties.enabled:\n        yield loss\n        return\n\n    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):\n        optimizers = [optimizers]\n\n    loss_scaler = _amp_state.loss_scalers[loss_id]\n    loss_scale = loss_scaler.loss_scale()\n\n    if ((not _amp_state.opt_properties.master_weights)\n        and (not loss_scaler.dynamic)\n        and loss_scale == 1.0):\n        yield loss.float()\n        # Needing to drop the cache here as well is an ugly gotcha.\n        # But for now I think it's necessary to short-circuit.\n        # Probably ok to skip this if not delay_unscale\n        if _amp_state.opt_properties.patch_torch_functions:\n            _amp_state.handle._clear_cache()\n        return\n\n    if not delay_unscale:\n        if isinstance(optimizers, list):\n            for optimizer in optimizers:\n                if not optimizer._amp_stash.params_have_scaled_gradients:\n                    optimizer._prepare_amp_backward()\n\n    yield (loss.float())*loss_scale\n\n    if delay_unscale:\n        for optimizer in optimizers:\n            optimizer._amp_stash.params_have_scaled_gradients = True\n    else:\n        # FusedSGD may take care of unscaling as part of their step() methods.\n        # if not isinstance(optimizers, FP16_Optimizer_for_fused):\n            loss_scaler.clear_overflow_state()\n            for optimizer in optimizers:\n                optimizer._post_amp_backward(loss_scaler)\n                optimizer._amp_stash.params_have_scaled_gradients = False\n            # For future fused optimizers that enable sync-free dynamic loss scaling,\n            # should_skip will always be False.\n            should_skip = False if delay_overflow_check else loss_scaler.update_scale()\n            if should_skip:\n                for optimizer in optimizers:\n                 
   if not optimizer._amp_stash.already_patched:\n                        # Close on loss_scaler and loss_id as well, to be safe.  Probably not\n                        # necessary because amp.scale_loss is already creating a temporary scope.\n                        def patch_step(opt, loss_scaler, loss_id):\n                            opt_step = opt.step\n                            def skip_step(closure=None):\n                                if closure is not None:\n                                    raise RuntimeError(\"Currently, Amp does not support closure use with optimizers.\")\n                                maybe_print((\"Gradient overflow.  Skipping step, loss scaler \" +\n                                             \"{} reducing loss scale to {}\").format(loss_id,\n                                             loss_scaler.loss_scale()))\n                                # TODO:  I don't like the special casing for different optimizer implementations.\n                                # Maybe skip should delegate to a method owned by the optimizers themselves.\n                                if hasattr(opt._amp_stash, \"all_fp32_from_fp16_params\"):\n                                    # Clear the master grads that wouldn't be zeroed by model.zero_grad()\n                                    for param in opt._amp_stash.all_fp32_from_fp16_params:\n                                        param.grad = None\n                                if hasattr(opt, \"most_recent_scale\"):\n                                    opt.most_recent_scale = 1.0\n                                    opt.scale_set_by_backward = False\n                                opt.step = opt_step\n                                opt._amp_stash.already_patched = False\n                            return skip_step\n                        optimizer.step = patch_step(optimizer, loss_scaler, loss_id)\n                        optimizer._amp_stash.already_patched = True\n\n    # Probably ok to 
skip this if not delay_unscale\n    if _amp_state.opt_properties.patch_torch_functions:\n        _amp_state.handle._clear_cache()\n\n\n# Free function version of AmpHandle.disable_casts, another step on the\n# path to removing the concept of \"AmpHandle\"\n@contextlib.contextmanager\ndef disable_casts():\n    _amp_state.handle._is_active = False\n    yield\n    _amp_state.handle._is_active = True\n\n\nclass AmpHandle(object):\n    def __init__(self, loss_scale=\"dynamic\", enable_caching=True, verbose=False):\n        self._enable_caching = enable_caching\n        self._verbose = verbose\n        self._cache = dict()\n        self._default_scaler = LossScaler(loss_scale)\n        self._is_active = True\n        self._all_wrappers = []\n\n    def is_active(self):\n        return self._is_active\n\n    @contextlib.contextmanager\n    def _disable_casts(self):\n        self._is_active = False\n        yield\n        self._is_active = True\n\n    def wrap_optimizer(self, optimizer, num_loss=1):\n        self._default_scaler = None\n        return OptimWrapper(optimizer, self, num_loss)\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss, optimizer):\n        raise RuntimeError(\"The old Amp API is no longer supported.  Please move to the new API, \"\n            \"documented here:  https://nvidia.github.io/apex/amp.html.  Transition guide:  \"\n            \"https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users\")\n\n        if not self.is_active():\n            yield loss\n            return\n\n        if self._default_scaler is None:\n            raise RuntimeError(\n                'After calling `handle.wrap_optimizer()`, you must explicitly ' +\n                'use `optimizer.scale_loss(loss)`.')\n\n        # TODO: this code block is duplicated here and `opt.py`. 
Unify.\n        loss_scale = self._default_scaler.loss_scale()\n        yield loss * loss_scale\n\n        self._default_scaler.clear_overflow_state()\n        self._default_scaler.unscale(\n            master_params(optimizer),\n            master_params(optimizer),\n            loss_scale)\n        should_skip = self._default_scaler.update_scale()\n        if should_skip:\n            optimizer_step = optimizer.step\n            def skip_step():\n                maybe_print('Gradient overflow, skipping update')\n                optimizer.step = optimizer_step\n            optimizer.step = skip_step\n\n        self._clear_cache()\n\n    def _clear_cache(self):\n        self._cache.clear()\n\n    # Experimental support for saving / restoring uncasted versions of functions\n    def _save_func(self, mod, fn, func):\n        self._all_wrappers.append((mod, fn, func))\n\n    def _deactivate(self):\n        for mod, fn, func in self._all_wrappers:\n            utils.set_func(mod, fn, func)\n        self._all_wrappers = []\n\n    @property\n    def has_cache(self):\n        return self._enable_caching\n\n    @property\n    def cache(self):\n        return self._cache\n\n    def remove_cache(self, param):\n        if self.has_cache and param in self.cache:\n            del self.cache[param]\n\n    @property\n    def verbose(self):\n        return self._verbose\n\nclass NoOpHandle(object):\n    def is_active(self):\n        return False\n\n    @contextlib.contextmanager\n    def _disable_casts(self):\n        yield\n\n    def wrap_optimizer(self, optimizer, num_loss=1):\n        return OptimWrapper(optimizer, self, num_loss)\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss, optimizer):\n        yield loss\n\n    @property\n    def has_cache(self):\n        return False\n\n    @property\n    def verbose(self):\n        return False\n\n    def _clear_cache(self):\n        pass\n\n    def _deactivate(self):\n        pass\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/lists/__init__.py",
    "content": ""
  },
  {
    "path": "KoSentenceT5/apex/amp/lists/functional_overrides.py",
    "content": "\n# TODO: think about the following two. They do weird things.\n# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)\n# - torch.nn.utils.weight_norm\n\n# Notes:\n# F.instance_norm uses batch_norm internally. Which correctly handles\n#   fp16 in/out with fp32 weights. So we shouldn't do anything for\n#   either of these.\n# F.normalize calls `input.norm()` internally, so it's redundant, but\n#   kept here in case impl. changes.\n# F.cosine_similarity is same: calls `x.norm()` internally.\n\nimport torch.nn.functional\n\nMODULE = torch.nn.functional\n\nFP16_FUNCS = [\n    'conv1d',\n    'conv2d',\n    'conv3d',\n    'conv_transpose1d',\n    'conv_transpose2d',\n    'conv_transpose3d',\n    'conv_tbc', # Undocumented / maybe new?\n    'linear',\n]\n\nFP32_FUNCS = [\n\n    # Interpolation/Upsampling TODO:  Remove for 1.2\n    'interpolate',\n    'grid_sample',\n\n    # Pointwise\n    'softplus',\n    'softmin',\n    'log_softmax',\n    'softmax',\n    'gelu',\n    \n    # Normalization\n    'layer_norm',\n    'group_norm',\n    'local_response_norm',\n    'normalize',\n    'cosine_similarity',\n\n    # Loss functions\n    # TODO: which of these can be fp16?\n    'poisson_nll_loss',\n    'cosine_embedding_loss',\n    'cross_entropy',\n    'hinge_embedding_loss',\n    'kl_div',\n    'l1_loss',\n    'mse_loss',\n    'margin_ranking_loss',\n    'multilabel_margin_loss',\n    'multilabel_soft_margin_loss',\n    'multi_margin_loss',\n    'nll_loss',\n    'binary_cross_entropy_with_logits',\n    'smooth_l1_loss',\n    'soft_margin_loss',\n    'triplet_margin_loss',\n    'ctc_loss'\n]\n\nBANNED_FUNCS = [\n    ('binary_cross_entropy',\n     (\"\\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` \"\n      \"It requires that the output of the previous function be already a FloatTensor. \\n\\n\"\n      \"Most models have a Sigmoid right before BCELoss. 
In that case, you can use\\n\"\n      \"    torch.nn.BCEWithLogitsLoss\\nto combine Sigmoid+BCELoss into a single layer \"\n      \"that is compatible with amp.\\nAnother option is to add\\n\"\n      \"    amp.register_float_function(torch, 'sigmoid')\\nbefore calling `amp.init()`.\\n\"\n      \"If you _really_ know what you are doing, you can disable this warning by passing \"\n      \"allow_banned=True to `amp.init()`.\"))\n]\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/lists/tensor_overrides.py",
    "content": "from .. import compat\nfrom . import torch_overrides\n\nimport importlib\n\nimport torch\n\n# if compat.variable_is_tensor() and not compat.tensor_is_variable():\nMODULE = torch.Tensor\n# else:\n#     MODULE = torch.autograd.Variable\n\n\nFP16_FUNCS = compat.filter_attrs(MODULE, [\n    '__matmul__',\n])\n\nFP32_FUNCS = compat.filter_attrs(MODULE, [\n    '__ipow__',\n    '__pow__',\n    '__rpow__',\n\n    # Cast to fp32 before transfer to CPU\n    'cpu',\n])\n\nCASTS = compat.filter_attrs(MODULE, [\n    '__add__',\n    '__div__',\n    '__eq__',\n    '__ge__',\n    '__gt__',\n    '__iadd__',\n    '__idiv__',\n    '__imul__',\n    '__isub__',\n    '__itruediv__',\n    '__le__',\n    '__lt__',\n    '__mul__',\n    '__ne__',\n    '__radd__',\n    '__rdiv__',\n    '__rmul__',\n    '__rsub__',\n    '__rtruediv__',\n    '__sub__',\n    '__truediv__',\n])\n\n# None of these, but here to make code cleaner.\nSEQUENCE_CASTS = []\n\n# We need to grab all the methods from torch_overrides and add them to\n# the Tensor lists as well, as almost all methods are duplicated\n# between `torch` and `torch.Tensor` (and check with `hasattr`,\n# because a few random ones aren't defined on Tensor)\n_self_mod = importlib.import_module(__name__)\nfor attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:\n    lst = getattr(_self_mod, attrname)\n    for fn in getattr(torch_overrides, attrname):\n        if hasattr(MODULE, fn):\n            lst.append(fn)\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/lists/torch_overrides.py",
    "content": "import torch\n\nfrom .. import utils\n\nMODULE = torch\n\nFP16_FUNCS = [\n    # Low level functions wrapped by torch.nn layers.\n    # The wrapper layers contain the weights which are then passed in as a parameter\n    # to these functions.\n    'conv1d',\n    'conv2d',\n    'conv3d',\n    'conv_transpose1d',\n    'conv_transpose2d',\n    'conv_transpose3d',\n    'conv_tbc',\n    'prelu',\n\n    # BLAS\n    'addmm',\n    'addmv',\n    'addr',\n    'matmul',\n    'mm',\n    'mv',\n]\n\nFP32_FUNCS = [\n    # Pointwise\n    'acos',\n    'asin',\n    'cosh',\n    'erfinv',\n    'exp',\n    'expm1',\n    'log',\n    'log10',\n    'log2',\n    'reciprocal',\n    'rsqrt',\n    'sinh',\n    'tan',\n\n    # Other math\n    'pow',\n\n    # Reduction\n    'cumprod',\n    'cumsum',\n    'dist',\n    # 'mean',\n    'norm',\n    'prod',\n    'std',\n    'sum',\n    'var',\n\n    # Misc\n    'renorm'\n]\n\n# Compare (major, minor) as an integer tuple: a float-based comparison\n# (float(\"1.10\") == 1.1) would misclassify torch 1.10 and later.\nversion_strings = torch.__version__.split('.')\nversion_major = int(version_strings[0])\nversion_minor = int(version_strings[1])\n# Before torch 1.1, mean must be blacklisted.\nif (version_major, version_minor) < (1, 1):\n    FP32_FUNCS.append('mean')\n\n# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We\n# check the CUDA version -- if at least 9.1, then put the bmm\n# functions on the fp16 list. 
Otherwise, put them on the fp32 list.\n_bmms = ['addbmm',\n         'baddbmm',\n         'bmm']\n\nif utils.is_cuda_enabled():\n  # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802\n  if utils.get_cuda_version() >= (9, 1, 0):\n      FP16_FUNCS.extend(_bmms)\n  else:\n      FP32_FUNCS.extend(_bmms)\n\n# Multi-tensor fns that may need type promotion\nCASTS = [\n    # Multi-tensor math\n    'addcdiv',\n    'addcmul',\n    'atan2',\n    'cross',\n    'bilinear',\n    'dot',\n\n    # Element-wise _or_ tensor-wise math\n    'add',\n    'div',\n    'mul',\n\n    # Comparison\n    'eq',\n    'equal',\n    'ge',\n    'gt',\n    'le',\n    'lt',\n    'ne'\n]\n\n# Functions that take sequence arguments. We need to inspect the whole\n# sequence and cast to the widest type.\nSEQUENCE_CASTS = [\n    'cat',\n    'stack'\n]\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/opt.py",
    "content": "import contextlib\nimport warnings\n\nfrom .scaler import LossScaler, master_params\nfrom ._amp_state import maybe_print\n\nimport numpy as np\n\nclass OptimWrapper(object):\n    def __init__(self, optimizer, amp_handle, num_loss):\n        self._optimizer = optimizer\n        self._amp_handle = amp_handle\n        self._num_loss = num_loss\n        self._loss_idx = 0\n        self._skip_next = [False] * num_loss\n        self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss):\n        if not self._amp_handle.is_active():\n            yield loss\n            return\n\n        # When there are multiple losses per optimizer, we need\n        # to save out the current grad accumulation, since we won't be\n        # able to unscale this particular loss once the grads are\n        # all mixed together.\n        cached_grads = []\n        if self._loss_idx > 0:\n            for p in master_params(self._optimizer):\n                if p.grad is not None:\n                    cached_grads.append(p.grad.data.detach().clone())\n                else:\n                    cached_grads.append(None)\n            self._optimizer.zero_grad()\n\n        loss_scale = self._cur_loss_scaler().loss_scale()\n        yield loss * loss_scale\n\n        self._cur_loss_scaler().clear_overflow_state()\n        self._cur_loss_scaler().unscale(\n            master_params(self._optimizer),\n            master_params(self._optimizer),\n            loss_scale)\n        self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()\n        self._loss_idx += 1\n\n        if len(cached_grads) > 0:\n            for p, cached_grad in zip(master_params(self._optimizer),\n                                      cached_grads):\n                if cached_grad is not None:\n                    p.grad.data.add_(cached_grad)\n            cached_grads = []\n\n    def _cur_loss_scaler(self):\n        assert 0 
<= self._loss_idx < self._num_loss\n        return self._loss_scaler[self._loss_idx]\n\n    def step(self, closure=None):\n        if not self._amp_handle.is_active():\n            return self._optimizer.step(closure=closure)\n\n        self._loss_idx = 0\n\n        for group in self._optimizer.param_groups:\n            for p in group['params']:\n                self._amp_handle.remove_cache(p)\n\n        if closure is not None:\n            raise NotImplementedError(\n                'The `closure` argument is unsupported by the amp ' +\n                'optimizer wrapper.')\n        if any(self._skip_next):\n            maybe_print('Gradient overflow, skipping update')\n            self._skip_next = [False] * self._num_loss\n        else:\n            return self._optimizer.step(closure=closure)\n\n    # Forward any attribute lookups\n    def __getattr__(self, attr):\n        return getattr(self._optimizer, attr)\n\n    # Forward all torch.optim.Optimizer methods\n    def __getstate__(self):\n        return self._optimizer.__getstate__()\n\n    def __setstate__(self, state):\n        return self._optimizer.__setstate__(state)\n\n    def __repr__(self):\n        return self._optimizer.__repr__()\n\n    def state_dict(self):\n        return self._optimizer.state_dict()\n\n    def load_state_dict(self, state_dict):\n        return self._optimizer.load_state_dict(state_dict)\n\n    def zero_grad(self):\n        return self._optimizer.zero_grad()\n\n    def add_param_group(self, param_group):\n        return self._optimizer.add_param_group(param_group)\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/rnn_compat.py",
    "content": "from . import utils, wrap\n\nimport torch\n_VF = torch._C._VariableFunctions\nRNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']\n\ndef _gen_VF_wrapper(name):\n    def wrapper(*args, **kwargs):\n        return getattr(_VF, name)(*args, **kwargs)\n    return wrapper\n\n# Some python magic to generate an object that has the rnn cell functions\n# defined on it, all of which call into corresponding _VF version.\n# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named \"_VF\"\n# imported at module scope within torch.nn.modules.rnn).  This should\n# not affect third-party importers of _VF.py.\nclass VariableFunctionsShim(object):\n    def __init__(self):\n        for name in RNN_NAMES:\n            for suffix in ['', '_cell']:\n               fn_name = name + suffix\n               setattr(self, fn_name, _gen_VF_wrapper(fn_name))\n\ndef has_old_rnns():\n    try:\n        torch.nn.backends.thnn.backend.LSTMCell\n        return True\n    except:\n        return False\n\ndef whitelist_rnn_cells(handle, verbose):\n    # Different module + function names in old/new RNN cases\n    if has_old_rnns():\n        fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']\n        mod = torch.nn.backends.thnn.backend\n    else:\n        fn_names = [x + '_cell' for x in RNN_NAMES]\n        mod = torch.nn.modules.rnn._VF\n        assert isinstance(mod, VariableFunctionsShim)\n\n    # Insert casts on cell functions\n    for fn in fn_names:\n        wrap.cached_cast(mod, fn, utils.maybe_half, handle,\n                         try_caching=True, verbose=verbose)\n\n    if has_old_rnns():\n        # Special handling of `backward` for fused gru / lstm:\n        # The `backward` method calls Tensor.sum() (blacklist) internally,\n        # and then the resulting grad_input has the wrong type.\n        # TODO: where else is this a problem?\n        for rnn_type in ['GRUFused', 'LSTMFused']:\n            mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, 
rnn_type)\n            wrap.disable_casts(mod, 'backward', handle)\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/scaler.py",
    "content": "import torch\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom ._amp_state import _amp_state, master_params, maybe_print\nfrom itertools import product\n\ndef scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):\n    # Exception handling for 18.04 compatibility\n    if check_overflow:\n        cpu_sum = float(model_grad.float().sum())\n        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n            return True\n\n    if master_grad is not model_grad: # copy_ probably internally short-circuits this\n        master_grad.copy_(model_grad)\n    if scale != 1.0:\n        master_grad.mul_(scale)\n    return False\n\ndef axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):\n    # Exception handling for 18.04 compatibility\n    if check_overflow:\n        cpu_sum = float(model_grad.float().sum())\n        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n            return True\n\n    # if master_grad is not model_grad: # copy_ probably internally short-circuits this\n    #     master_grad.copy_(model_grad)\n    assert stashed_grad.dtype == master_grad.dtype\n    converted_model_grad = model_grad.data.to(master_grad.dtype)\n    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data\n    return False\n\nclass LossScaler(object):\n    warned_no_fused_kernel = False\n    warned_unscaling_non_fp32_grad = False\n    has_fused_kernel = False\n\n    def __init__(self,\n                 loss_scale,\n                 init_scale=2.**16,\n                 scale_factor=2.,\n                 scale_window=2000,\n                 min_loss_scale=None,\n                 max_loss_scale=2.**24):\n        if loss_scale == \"dynamic\":\n            self.dynamic = True\n            self._loss_scale = min(max_loss_scale, init_scale)\n        else:\n            self.dynamic = False\n            self._loss_scale = 
loss_scale\n        self._max_loss_scale = max_loss_scale\n        self._min_loss_scale = min_loss_scale\n        self._scale_seq_len = scale_window\n        self._unskipped = 0\n        self._has_overflow = False\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        if multi_tensor_applier.available:\n            import amp_C\n            LossScaler.has_fused_kernel = multi_tensor_applier.available\n            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale\n            LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby\n        else:\n            if not LossScaler.warned_no_fused_kernel:\n                maybe_print(\n                    \"Warning:  multi_tensor_applier fused unscale kernel is unavailable, \"\n                    \"possibly because apex was installed without --cuda_ext --cpp_ext. \"\n                    \"Using Python fallback.  Original ImportError was: \" +\n                    repr(multi_tensor_applier.import_err),\n                    True)\n            LossScaler.has_fused_kernel = False\n            LossScaler.warned_no_fused_kernel = True\n\n    def loss_scale(self):\n        return self._loss_scale\n\n    def unscale_python(self, model_grads, master_grads, scale):\n        for model, master in zip(model_grads, master_grads):\n            if model is not None:\n                if not LossScaler.warned_unscaling_non_fp32_grad:\n                    if master.dtype != torch.float32:\n                        maybe_print(\n                            \"Attempting to unscale a grad with type {} \".format(master.type()) +\n                            \"Unscaling non-fp32 grads may indicate an error. 
\"\n                            \"When using Amp, you don't need to call .half() on your model.\")\n                        LossScaler.warned_unscaling_non_fp32_grad = True\n                self._has_overflow = scale_check_overflow_python(model,\n                                                                 master,\n                                                                 1./scale,\n                                                                 self.dynamic)\n                if self._has_overflow and self.dynamic:\n                    break\n\n    # unused_scale keeps some of the old API alive for hopefully a short time.\n    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):\n        if self._has_overflow:\n            return\n\n        scale = self._loss_scale\n        if scale_override is not None:\n            scale = scale_override\n\n        if scale == 1.0 and models_are_masters and not self.dynamic:\n            return\n\n        if LossScaler.has_fused_kernel:\n            # if (not LossScaler.warned_unscaling_non_fp32_grad\n            #     and master_grads[0].dtype == torch.float16):\n            #     print(\"Warning:  unscaling grads that are not FP32. \"\n            #           \"Unscaling non-fp32 grads may indicate an error. 
\"\n            #           \"When using Amp, you don't need to call .half() on your model.\")\n            #     # Setting this to True unconditionally allows the possibility of an escape\n            #     # if never-before-seen non-fp32 grads are created in some later iteration.\n            #     LossScaler.warned_unscaling_non_fp32_grad = True\n            multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,\n                                 self._overflow_buf,\n                                 [model_grads, master_grads],\n                                 1./scale)\n        else:\n            self.unscale_python(model_grads, master_grads, scale)\n\n        # Defer to update_scale\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n        #     self._has_overflow = self._overflow_buf.item()\n\n    def unscale_with_stashed_python(self,\n                                    model_grads,\n                                    stashed_master_grads,\n                                    master_grads,\n                                    a,\n                                    b):\n        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):\n            if model is None and stashed is None:\n                continue\n            else:\n                if not LossScaler.warned_unscaling_non_fp32_grad:\n                    if master.dtype != torch.float32:\n                        maybe_print(\n                            \"Attempting to unscale a grad with type {} \".format(master.type()) +\n                            \"Unscaling non-fp32 grads may indicate an error. 
\"\n                            \"When using Amp, you don't need to call .half() on your model.\")\n                        LossScaler.warned_unscaling_non_fp32_grad = True\n                self._has_overflow = axpby_check_overflow_python(model,\n                                                                 stashed,\n                                                                 master,\n                                                                 a,\n                                                                 b,\n                                                                 self.dynamic)\n                if self._has_overflow and self.dynamic:\n                    break\n\n    def unscale_with_stashed(self,\n                             model_grads,\n                             stashed_master_grads,\n                             master_grads,\n                             scale_override=None):\n        if self._has_overflow:\n            return\n\n        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0\n        if scale_override is not None:\n            grads_have_scale, stashed_have_scale, out_scale = scale_override\n\n        if LossScaler.has_fused_kernel:\n            if (not LossScaler.warned_unscaling_non_fp32_grad\n                and master_grads[0].dtype == torch.float16):\n                print(\"Warning:  unscaling grads that are not FP32. \"\n                      \"Unscaling non-fp32 grads may indicate an error. 
\"\n                      \"When using Amp, you don't need to call .half() on your model.\")\n                # Setting this to True unconditionally allows the possibility of an escape\n                # if never-before-seen non-fp32 grads are created in some later iteration.\n                LossScaler.warned_unscaling_non_fp32_grad = True\n            multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,\n                                 self._overflow_buf,\n                                 [model_grads, stashed_master_grads, master_grads],\n                                 out_scale/grads_have_scale,   # 1./scale,\n                                 out_scale/stashed_have_scale, # 1.0,\n                                 0) # check only arg 0, aka the incoming model grads, for infs\n        else:\n            self.unscale_with_stashed_python(model_grads,\n                                             stashed_master_grads,\n                                             master_grads,\n                                             out_scale/grads_have_scale,\n                                             out_scale/stashed_have_scale)\n\n        # Defer to update_scale\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n        #     self._has_overflow = self._overflow_buf.item()\n\n    def clear_overflow_state(self):\n        self._has_overflow = False\n        if self.has_fused_kernel:\n            self._overflow_buf.zero_()\n\n    # Separate so unscale() can be called more that once before updating.\n    def update_scale(self):\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n            self._has_overflow = self._overflow_buf.item()\n\n        if self._has_overflow and self.dynamic:\n            should_skip = True\n            
if(self._min_loss_scale):\n                self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)\n            else:\n                self._loss_scale = self._loss_scale/2.\n            self._unskipped = 0\n        else:\n            should_skip = False\n            self._unskipped += 1\n\n        if self._unskipped == self._scale_seq_len and self.dynamic:\n            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)\n            self._unskipped = 0\n\n        return should_skip\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/utils.py",
    "content": "from . import compat\n\nimport functools\nimport itertools\n\nimport torch\n\ndef is_cuda_enabled():\n    return torch.version.cuda is not None\n\ndef get_cuda_version():\n    return tuple(int(x) for x in torch.version.cuda.split('.'))\n\ndef is_fp_tensor(x):\n    if is_nested(x):\n        # Fast-fail version of all(is_fp_tensor)\n        for y in x:\n            if not is_fp_tensor(y):\n                return False\n        return True\n    return compat.is_tensor_like(x) and compat.is_floating_point(x)\n\ndef is_nested(x):\n    return isinstance(x, tuple) or isinstance(x, list)\n\ndef should_cache(x):\n    if is_nested(x):\n        # Fast-fail version of all(should_cache)\n        for y in x:\n            if not should_cache(y):\n                return False\n        return True\n    return isinstance(x, torch.nn.parameter.Parameter) and \\\n        type_string(x) == 'FloatTensor'\n\ndef collect_fp_tensor_types(args, kwargs):\n    def collect_types(x, types):\n        if is_nested(x):\n            for y in x:\n                collect_types(y, types)\n        else:\n            types.add(type_string(x))\n\n    all_args = itertools.chain(args, kwargs.values())\n    types = set()\n    for x in all_args:\n        if is_fp_tensor(x):\n            collect_types(x, types)\n    return types\n\ndef type_string(x):\n    return x.type().split('.')[-1]\n\ndef maybe_half(x, name='', verbose=False):\n    if is_nested(x):\n        return type(x)([maybe_half(y) for y in x])\n\n    if not x.is_cuda or type_string(x) == 'HalfTensor':\n        return x\n    else:\n        if verbose:\n            print('Float->Half ({})'.format(name))\n        return x.half()\n\ndef maybe_float(x, name='', verbose=False):\n    if is_nested(x):\n        return type(x)([maybe_float(y) for y in x])\n\n    if not x.is_cuda or type_string(x) == 'FloatTensor':\n        return x\n    else:\n        if verbose:\n            print('Half->Float ({})'.format(name))\n        return 
x.float()\n\n# NB: returns casted `args`, mutates `kwargs` in-place\ndef casted_args(cast_fn, args, kwargs):\n    new_args = []\n    for x in args:\n        if is_fp_tensor(x):\n            new_args.append(cast_fn(x))\n        else:\n            new_args.append(x)\n    for k in kwargs:\n        val = kwargs[k]\n        if is_fp_tensor(val):\n            kwargs[k] = cast_fn(val)\n    return new_args\n\ndef cached_cast(cast_fn, x, cache):\n    if is_nested(x):\n        # Recurse into nested containers, threading cast_fn and cache through.\n        return type(x)([cached_cast(cast_fn, y, cache) for y in x])\n    if x in cache:\n        cached_x = cache[x]\n        if x.requires_grad and cached_x.requires_grad:\n            # Make sure x is actually cached_x's autograd parent.\n            if cached_x.grad_fn.next_functions[1][0].variable is not x:\n                raise RuntimeError(\"x and cache[x] both require grad, but x is not \"\n                                   \"cache[x]'s parent.  This is likely an error.\")\n        # During eval, it's possible to end up caching casted weights with\n        # requires_grad=False.  On the next training iter, if cached_x is found\n        # and reused from the cache, it will not actually have x as its parent.\n        # Therefore, we choose to invalidate the cache (and force refreshing the cast)\n        # if x.requires_grad and cached_x.requires_grad do not match.\n        #\n        # During eval (i.e. running under with torch.no_grad()) the invalidation\n        # check would cause the cached value to be dropped every time, because\n        # cached_x would always be created with requires_grad=False, while x would\n        # still have requires_grad=True.  This would render the cache effectively\n        # useless during eval.  Therefore, if we are running under the no_grad()\n        # context manager (torch.is_grad_enabled=False) we elide the invalidation\n        # check, and use the cached value even though its requires_grad flag doesn't\n        # match.  
During eval, we don't care that there's no autograd-graph\n        # connection between x and cached_x.\n        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:\n            del cache[x]\n        else:\n            return cached_x\n\n    casted_x = cast_fn(x)\n    cache[x] = casted_x\n    return casted_x\n\ndef verbosify(cast_fn, fn_name, verbose):\n    if verbose:\n        return functools.partial(cast_fn, name=fn_name, verbose=verbose)\n    else:\n        return cast_fn\n\ndef as_inplace(fns):\n    for x in fns:\n        yield x + '_'\n\ndef has_func(mod, fn):\n    if isinstance(mod, dict):\n        return fn in mod\n    else:\n        return hasattr(mod, fn)\n\ndef get_func(mod, fn):\n    if isinstance(mod, dict):\n        return mod[fn]\n    else:\n        return getattr(mod, fn)\n\ndef set_func(mod, fn, new_fn):\n    if isinstance(mod, dict):\n        mod[fn] = new_fn\n    else:\n        setattr(mod, fn, new_fn)\n\ndef set_func_save(handle, mod, fn, new_fn):\n    cur_fn = get_func(mod, fn)\n    handle._save_func(mod, fn, cur_fn)\n    set_func(mod, fn, new_fn)\n\n# A couple problems get solved here:\n# - The flat_weight buffer is disconnected from autograd graph,\n#   so the fp16 weights need to be derived from the input weights\n#   to this forward call, not the flat buffer.\n# - The ordering of weights in the flat buffer is...idiosyncratic.\n# First problem is solved with combination of set_ (to set up\n# correct storage) and copy_ (so the fp16 weight derives from the\n# fp32 one in autograd.\n# Second is solved by doing ptr arithmetic on the fp32 weights\n# to derive the correct offset.\n#\n# TODO: maybe this should actually use\n# `torch._cudnn_rnn_flatten_weight`? But then I need to call\n# on first iter and cache the right offsets. 
Ugh.\ndef synthesize_flattened_rnn_weights(fp32_weights,\n                                     fp16_flat_tensor,\n                                     rnn_fn='',\n                                     verbose=False):\n    fp16_weights = []\n    fp32_base_ptr = fp32_weights[0][0].data_ptr()\n    for layer_weights in fp32_weights:\n        fp16_layer_weights = []\n        for w_fp32 in layer_weights:\n            w_fp16 = w_fp32.new().half()\n            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()\n            w_fp16.set_(fp16_flat_tensor.storage(),\n                        offset,\n                        w_fp32.shape)\n            w_fp16.copy_(w_fp32)\n            if verbose:\n                print('Float->Half ({})'.format(rnn_fn))\n            fp16_layer_weights.append(w_fp16)\n        fp16_weights.append(fp16_layer_weights)\n    return fp16_weights\n\n# Roughly same as above, just the `fp32_weights` aren't nested.\n# Code kept separate for readability.\ndef new_synthesize_flattened_rnn_weights(fp32_weights,\n                                         fp16_flat_tensor,\n                                         rnn_fn='',\n                                         verbose=False):\n    fp16_weights = []\n    fp32_base_ptr = fp32_weights[0].data_ptr()\n    for w_fp32 in fp32_weights:\n        w_fp16 = w_fp32.new().half()\n        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()\n        w_fp16.set_(fp16_flat_tensor.storage(),\n                    offset,\n                    w_fp32.shape)\n        w_fp16.copy_(w_fp32)\n        if verbose:\n            print('Float->Half ({})'.format(rnn_fn))\n        fp16_weights.append(w_fp16)\n    return fp16_weights\n"
  },
  {
    "path": "KoSentenceT5/apex/amp/wrap.py",
    "content": "from . import compat\nfrom . import utils\nfrom ._amp_state import _amp_state\nfrom . import rnn_compat\n\nimport functools\n\nimport torch\n\ndef make_cast_wrapper(orig_fn, cast_fn, handle,\n                      try_caching=False):\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        if not handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        if try_caching and handle.has_cache:\n            args = list(args)\n            for i in range(len(args)):\n                if utils.should_cache(args[i]):\n                    args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)\n            for k in kwargs:\n                if utils.should_cache(kwargs[k]):\n                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)\n        new_args = utils.casted_args(cast_fn,\n                                     args,\n                                     kwargs)\n        return orig_fn(*new_args, **kwargs)\n    return wrapper\n\ndef cached_cast(mod, fn, cast_fn, handle,\n                try_caching=False, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    cast_fn = utils.verbosify(cast_fn, fn, verbose)\n    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\n# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`\n# Annoyingly, make_promote_wrapper still uses the global handle.  
Once everyone\n# is on the new API and I am free to get rid of handle, I can clean this up.\ndef make_promote_wrapper(orig_fn, cast_fn, handle=None):\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        if not _amp_state.handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        types = utils.collect_fp_tensor_types(args, kwargs)\n\n        if len(types) <= 1:\n            return orig_fn(*args, **kwargs)\n        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):\n            new_args = utils.casted_args(cast_fn,\n                                         args,\n                                         kwargs)\n            return orig_fn(*new_args, **kwargs)\n        else:\n            raise NotImplementedError('Do not know how to handle ' +\n                                      'these types to promote: {}'\n                                      .format(types))\n    return wrapper\n\ndef promote(mod, fn, handle, verbose=False):\n    orig_fn = utils.get_func(mod, fn)\n    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)\n    wrapper = make_promote_wrapper(orig_fn, maybe_float)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef sequence_promote(mod, fn, handle, verbose=False):\n    orig_fn = utils.get_func(mod, fn)\n    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)\n    @functools.wraps(orig_fn)\n    def wrapper(seq, *args, **kwargs):\n        if not _amp_state.handle.is_active():\n            return orig_fn(seq, *args, **kwargs)\n\n        types = set([utils.type_string(x) for x in seq])\n        if len(types) <= 1:\n            return orig_fn(seq, *args, **kwargs)\n        elif types == set(['HalfTensor', 'FloatTensor']):\n            cast_seq = utils.casted_args(maybe_float,\n                                         seq, {})\n            return orig_fn(cast_seq, *args, **kwargs)\n        else:\n            # TODO: other mixed-type cases aren't due to amp.\n            #   
    Just pass through?\n            return orig_fn(seq, *args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef promote_match_arg0(mod, fn, handle, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(arg0, *args, **kwargs):\n        assert compat.is_tensor_like(arg0)\n        if not _amp_state.handle.is_active():\n            return orig_fn(arg0, *args, **kwargs)\n\n        if utils.type_string(arg0) == 'HalfTensor':\n            cast_fn = utils.maybe_half\n        elif utils.type_string(arg0) == 'FloatTensor':\n            cast_fn = utils.maybe_float\n        else:\n            return orig_fn(arg0, *args, **kwargs)\n        cast_fn = utils.verbosify(cast_fn, fn, verbose)\n        new_args = utils.casted_args(cast_fn, args, kwargs)\n        return orig_fn(arg0, *new_args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef err_if_any_half(mod, fn, handle, custom_err_msg=None):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        types = utils.collect_fp_tensor_types(args, kwargs)\n        if 'HalfTensor' in types:\n            if custom_err_msg:\n                raise NotImplementedError(custom_err_msg)\n            else:\n                raise NotImplementedError('Cannot call in-place function ' +\n                                          '{} with fp16 arguments.'.format(fn))\n        else:\n            return orig_fn(*args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef err_if_arg0_half(mod, fn, handle, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(arg0, *args, **kwargs):\n        assert compat.is_tensor_like(arg0)\n        if utils.type_string(arg0) == 'HalfTensor':\n            
raise NotImplementedError('Cannot call in-place method ' +\n                                      '{} on fp16 Tensors.'.format(fn))\n        else:\n            cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)\n            new_args = utils.casted_args(cast_fn, args, kwargs)\n            return orig_fn(arg0, *new_args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\n# Current RNN approach:\n# - Wrap top-level `RNN` function in thnn backend\n# - Will call into either CudnnRNN or AutogradRNN\n#  - Each of these are factory functions that return a per-iter\n#    `forward` function\n# - We interpose on the factory function to:\n#   1) Interpose on the actual forward function and put in casts\n#   2) Insert an fp16 `flat_weight` if necessary\ndef rnn_cast(backend, fn, handle, verbose=False):\n    orig_rnn = utils.get_func(backend, fn)\n    @functools.wraps(orig_rnn)\n    def rnn_wrapper(*args, **kwargs):\n        flat_weight = kwargs.get('flat_weight')\n        if flat_weight is not None:\n            # We replace `flat_weight` with an uninitialized fp16\n            # Tensor. The \"actual\" weight tensors (provided in `forward`),\n            # will then be set up as ptrs into the buffer and have the\n            # corresponding fp32 values copied in.\n            # We need to call `copy` on the \"actual\" weights so that the\n            # autograd graph correctly backprops from the wgrads computed\n            # inside cuDNN (on fp16 weights) into the fp32 weights.\n            assert utils.type_string(flat_weight) == 'FloatTensor'\n            if compat.tensor_is_float_tensor() or compat.tensor_is_variable():\n                # Pre-0.4. 
A little slower, since it zeros out memory.\n                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)\n            else:\n                flat_weight_fp16 = torch.empty_like(flat_weight,\n                                                    dtype=torch.float16)\n            kwargs['flat_weight'] = flat_weight_fp16\n        else:\n            flat_weight_fp16 = None\n\n        forward = orig_rnn(*args, **kwargs)\n        @functools.wraps(forward)\n        def fwd_wrapper(*fargs, **fkwargs):\n            assert len(fargs) == 3 or len(fargs) == 4\n            inputs, weights, hiddens = fargs[:3]\n            assert utils.is_fp_tensor(inputs)\n            assert isinstance(weights, list)\n            cast_fn = utils.verbosify(utils.maybe_half,\n                                      fn,\n                                      verbose)\n            new_args = []\n\n            # 0) Inputs\n            new_args.append(cast_fn(inputs))\n\n            # 1) Weights\n            if flat_weight_fp16 is not None:\n                fp16_weights = utils.synthesize_flattened_rnn_weights(\n                    weights, flat_weight_fp16, fn, verbose)\n            else:\n                fp16_weights = [[cast_fn(w) for w in layer]\n                                for layer in weights]\n            new_args.append(fp16_weights)\n\n            # 2) Hiddens: either a tuple (for LSTM) or single tensor\n            if isinstance(hiddens, tuple):\n                new_args.append(tuple(cast_fn(x) for x in hiddens))\n            elif utils.is_fp_tensor(hiddens):\n                new_args.append(cast_fn(hiddens))\n            else:\n                # Hiddens can, in principle, be `None` -- pass through\n                new_args.append(hiddens)\n\n            # 3) Batch sizes (0.4 or later only)\n            if len(fargs) == 4:\n                new_args.append(fargs[3])\n\n            return forward(*new_args, **fkwargs)\n        return fwd_wrapper\n    
utils.set_func_save(handle, backend, fn, rnn_wrapper)\n\ndef new_rnn_cast(fn, handle, verbose=False):\n    # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744\n    # For rnn backend calls that route through _rnn_impls, we must patch the ref\n    # that _rnn_impls stashed.  For rnn backend calls that directly invoke\n    # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,\n    # which in turn has patched the ref named \"_VF\" in torch.nn.modules.rnn.\n    if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):\n        mod = torch.nn.modules.rnn._rnn_impls\n    else:\n        mod = torch.nn.modules.rnn._VF\n        assert isinstance(mod, rnn_compat.VariableFunctionsShim)\n        fn = fn.lower()\n    orig_fn = utils.get_func(mod, fn)\n    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        # Exact call signature from modules/rnn.py\n        assert len(args) == 9\n        assert len(kwargs) == 0\n\n        if not _amp_state.handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        if isinstance(args[6], bool):\n            params_idx = 2 # Not PackedSequence case\n        else:\n            params_idx = 3 # PackedSequence case\n\n        new_args = []\n        for i, arg in enumerate(args):\n            if i == params_idx:\n                num_params = sum([x.numel() for x in arg])\n                fp16_weight_buf = args[0].new_empty((num_params,),\n                                                    dtype=torch.half)\n                casted_weights = utils.new_synthesize_flattened_rnn_weights(\n                    arg, fp16_weight_buf, fn, verbose)\n                new_args.append(casted_weights)\n            elif utils.is_fp_tensor(arg):\n                new_args.append(cast_fn(arg))\n            else:\n                new_args.append(arg)\n\n        return orig_fn(*new_args)\n    utils.set_func_save(handle, mod, 
fn, wrapper)\n\ndef disable_casts(mod, fn, handle):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        with handle._disable_casts():\n            return orig_fn(*args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/__init__.py",
    "content": ""
  },
  {
    "path": "KoSentenceT5/apex/contrib/bottleneck/__init__.py",
    "content": "from .bottleneck import Bottleneck\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/bottleneck/bottleneck.py",
    "content": "import torch\nfrom torch import nn\nimport fast_bottleneck\n\ndef kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):\n    weight_tensor_nchw = tensor\n    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)\n\nclass FrozenBatchNorm2d(torch.nn.Module):\n    \"\"\"\n    BatchNorm2d where the batch statistics and the affine parameters are fixed\n    \"\"\"\n    def __init__(self, n):\n        super(FrozenBatchNorm2d, self).__init__()\n        self.register_buffer(\"weight\", torch.ones(n))\n        self.register_buffer(\"bias\", torch.zeros(n))\n        self.register_buffer(\"running_mean\", torch.zeros(n))\n        self.register_buffer(\"running_var\", torch.ones(n))\n\n    def get_scale_bias(self, nhwc=False):\n        scale = self.weight * self.running_var.rsqrt()\n        bias = self.bias - self.running_mean * scale\n        if nhwc:\n            scale = scale.reshape(1, 1, 1, -1)\n            bias = bias.reshape(1, 1, 1, -1)\n        else:\n            scale = scale.reshape(1, -1, 1, 1)\n            bias = bias.reshape(1, -1, 1, 1)\n        return scale, bias\n\n    def forward(self, x):\n        scale, bias = self.get_scale_bias()\n        return x * scale + bias\n\n\n@torch.jit.script\ndef drelu_dscale1(grad_o, output, scale1):\n    relu_mask = (output>0).half()\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    return g1, dx_relu\n\n@torch.jit.script\ndef drelu_dscale2(grad_o, output, scale1, scale2):\n    relu_mask = (output>0).half()\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    g2 = dx_relu * scale2\n    return g1, g2\n\nclass BottleneckFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):\n        # TODO: clean up order of tensors\n        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]\n        ctx.downsample = len(conv) > 3\n        if ctx.downsample:\n            
args.append(conv[3])\n            args.append(scale[3])\n            args.append(bias[3])\n\n        # weight buffers are always in nhwc while shape can be nhwc or channels_last\n        # here we pass in flag and let c++ handle it\n        # alternatively, we can put all sizes into a fixed format and pass it in\n        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)\n        ctx.save_for_backward(*(args+outputs))\n        # save relu outputs for drelu\n        ctx.nhwc = nhwc\n        ctx.stride_1x1 = stride_1x1\n        return outputs[2]\n\n    # backward relu is not exposed, MUL with mask used now\n    # only support dgrad\n    @staticmethod\n    def backward(ctx, grad_o):\n        outputs = ctx.saved_tensors[-3:]\n\n        if ctx.downsample:\n            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])\n        else:\n            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])\n\n        # create input vector for backward\n        t_list = [*ctx.saved_tensors[0:10]]\n        t_list.append(grad_conv3)\n        t_list.append(grad_conv4)\n\n        # outputs used for wgrad and generating drelu mask\n        t_list.append(outputs[0])\n        t_list.append(outputs[1])\n\n        # in case there is downsample\n        if ctx.downsample:\n            t_list.append(ctx.saved_tensors[10])\n\n        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)\n\n        return (None, None, None, None, *grads)\n\nbottleneck_function = BottleneckFunction.apply\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, 
kernel_size=1, stride=stride, bias=False)\n\nclass Bottleneck(torch.nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\", https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n    # here we put it at 1x1\n\n    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,\n                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):\n        super(Bottleneck, self).__init__()\n        if groups != 1:\n            raise RuntimeError('Only support groups == 1')\n        if dilation != 1:\n            raise RuntimeError('Only support dilation == 1')\n        if norm_func is None:\n            norm_func = FrozenBatchNorm2d\n        else:\n            raise RuntimeError('Only support frozen BN now.')\n\n        if stride != 1 or in_channels != out_channels:\n            self.downsample = nn.Sequential(\n                conv1x1(in_channels, out_channels, stride),\n                norm_func(out_channels),\n            )\n        else:\n            self.downsample = None\n\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)\n        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)\n        self.conv3 = conv1x1(bottleneck_channels, out_channels)\n        self.relu = nn.ReLU(inplace=True)\n        self.stride = stride\n\n        self.bn1 = norm_func(bottleneck_channels)\n        self.bn2 = norm_func(bottleneck_channels)\n        self.bn3 = norm_func(out_channels)\n\n        self.use_cudnn = use_cudnn\n\n        # setup conv weights\n      
  self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]\n        if self.downsample is not None:\n            self.w_conv.append(self.downsample[0].weight)\n\n        # init weight in nchw format before possible transpose\n        for w in self.w_conv:\n            kaiming_uniform_(w, a=1)\n\n        # TODO: prevent unsupported case usage\n        # support cases\n        #                 native      cudnn\n        # normal             yes         no\n        # channel_last       yes        yes\n        # explicit_nhwc       no        yes\n        self.explicit_nhwc = explicit_nhwc\n        if self.explicit_nhwc:\n            for p in self.parameters():\n                with torch.no_grad():\n                    p.data = p.data.permute(0,2,3,1).contiguous()\n        return\n\n    def forward(self, x):\n        if self.use_cudnn:\n            # calculate scale/bias from registered buffers\n            # TODO: make this better\n            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)\n            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)\n            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)\n            w_scale = [s1, s2, s3]\n            w_bias = [b1, b2, b3]\n            if self.downsample is not None:\n                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)\n                w_scale.append(s4)\n                w_bias.append(b4)\n\n            out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)\n            return out\n\n        if self.explicit_nhwc:\n            raise RuntimeError('explicit nhwc with native ops is not supported.')\n\n        # fallback to native ops\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if 
self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/bottleneck/test.py",
    "content": "import torch\nfrom bottleneck import Bottleneck\ntorch.manual_seed(23337)\n\n# use True to print layerwise sum for all outputs in reference code path\nDEBUG = False  # True\n\nfor stride, o_channel in [(1,32), (1,128), (2,32)]:\n    print(\"testing stride ==\", stride, \", in_channel == 32, out_channel ==\", o_channel)\n    a_ = torch.randn(17,32,28,28)\n\n    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()\n    model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)\n\n    # test model\n    b = model(a)\n    b.mean().backward()\n    d_grad = a.grad.float()\n    a.grad = None\n    torch.cuda.synchronize()\n\n    if DEBUG:\n        print(\"[DEBUG] ref dx :\", d_grad.sum().item())\n        # print wgrad. we don't need to reset since later cpp print before accumulation\n        for i, w in enumerate(model.w_conv):\n            print(\"[DEBUG] ref wgrad{} :\".format(i+1), w.grad.sum().item())\n\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float())\n\n    model.use_cudnn = True\n    model.zero_grad()\n    c = model(a)\n    c.mean().backward()\n\n    torch.cuda.synchronize()\n    print(\"comparing native and channels_last:\")\n    print(\"max error fprop:\", (b-c).abs().max().item(), \"max elem:\", b.abs().max().item())\n    print(\"max error dgrad:\", (d_grad-a.grad.float()).abs().max().item(), \"max elem:\", d_grad.abs().max().item())\n    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):\n        print(\"max error wgrad{}:\".format(i+1), (wgrad - w.grad.float()).abs().max().item(), \"max elem:\", wgrad.abs().max().item())\n\n    nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()\n    nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()\n    for p,q in zip(model.parameters(), nhwc_model.parameters()):\n        # model's storage is already in nhwc, we clone and assign to 
explicit nhwc model\n        q.data.copy_(p.data.permute(0,2,3,1).contiguous())\n    for p,q in zip(model.buffers(), nhwc_model.buffers()):\n        q.data.copy_(p.data)\n\n    d = nhwc_model(nhwc_a)\n    d.mean().backward()\n    torch.cuda.synchronize()\n\n    # reset reference to cudnn channels_last permute\n    #c_s = c.storage().tolist()\n    #d_s = d.storage().tolist()\n    #print(max([x-y for x,y in zip(c_s,d_s)]))\n    c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()\n    d_grad = a.grad.float().permute(0,2,3,1).contiguous()\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())\n\n    torch.cuda.synchronize()\n    print(\"comparing nhwc and channels_last:\")\n    print(\"max error fprop:\", (d-c).abs().max().item(), \"max elem:\", c.abs().max().item())\n    print(\"max error dgrad:\", (d_grad-nhwc_a.grad.float()).abs().max().item(), \"max elem:\", d_grad.abs().max().item())\n    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):\n        print(\"max error wgrad{}:\".format(i+1), (wgrad - w.grad.float()).abs().max().item(), \"max elem:\", wgrad.abs().max().item())\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/bottleneck/bottleneck.cpp",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h>  // for getcudnnhandle\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <vector>\n#include <cudnn_frontend.h>\n\n#include <iostream>\n\n#ifdef DEBUG\n#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )\n#else\n#define DEBUG_MSG(str) do { } while ( false )\n#endif\n\n#ifdef DEBUG_CUDNN\n#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )\n#else\n#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )\n#endif\n\n#define checkCudnnErr(...)                                                        \\\n    do {                                                                          \\\n        int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n        if (err) {                                                                \\\n            return;                                                    \\\n\t}                                                                         \\\n    } while (0)\n\n\nint checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {\n    if (code) {\n        printf(\"CUDNN error at %s:%d, code=%d (%s) in '%s'\\n\", file, line, (int)code, cudnnGetErrorString(code), expr);\n        return 1;\n    }\n    return 0;\n}\n\nvoid checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);\n#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); }    // in-line regular function\n\nvoid checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort)\n{\n  if (code != cudaSuccess)\n  {\n    const char * errorMessage = cudaGetErrorString(code);\n    fprintf(stderr, \"CUDA error returned from \\\"%s\\\" at %s:%d, Error code: %d (%s)\\n\", func, file, line, code, errorMessage);\n    if (abort){\n      cudaDeviceReset();\n      exit(code);\n    }\n  }\n}\n\nvoid 
generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {\n    // For INT8x4 and INT8x32 we still compute standard strides here to input\n    // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.\n    if (filterFormat == CUDNN_TENSOR_NCHW) {\n        strideA[nbDims - 1] = 1;\n        for (int64_t d = nbDims - 2; d >= 0; d--) {\n            strideA[d] = strideA[d + 1] * dimA[d + 1];\n        }\n    } else {\n        // Here we assume that the format is CUDNN_TENSOR_NHWC\n\tstrideA[1]          = 1;\n        strideA[nbDims - 1] = strideA[1] * dimA[1];\n        for (int64_t d = nbDims - 2; d >= 2; d--) {\n            strideA[d] = strideA[d + 1] * dimA[d + 1];\n        }\n        strideA[0] = strideA[2] * dimA[2];\n    }\n}\n\n\nint getFwdConvDilatedFilterDim(int filterDim, int dilation) {\n    return ((filterDim - 1) * dilation) + 1;\n}\n\nint getFwdConvPaddedImageDim(int tensorDim, int pad) {\n    return tensorDim + (2 * pad);\n}\n\nint getFwdConvOutputDim(\n    int tensorDim,\n    int pad,\n    int filterDim,\n    int stride,\n    int dilation)\n{\n    int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;\n    return (p);\n}\n\nenum {\n    X_TENSOR,\n    Y_TENSOR,\n    W_TENSOR,\n    Z_TENSOR,\n    B_TENSOR,\n    AFTERADD_TENSOR,\n    AFTERBIAS_TENSOR,\n    AFTERCONV_TENSOR,\n    OPTIONAL,\n    AFTEROPT_TENSOR,\n};\n\nusing common_conv_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;\n\n\ncommon_conv_descriptors\ncreate_common_descriptors(int64_t* x_dim_padded,\n                          int64_t* padA,\n                          int64_t* convstrideA,\n                          int64_t* dilationA,\n                          int64_t* w_dim_padded,\n                          int64_t* y_dim_padded,\n                          cudnnDataType_t dataType,\n        
                  cudnnConvolutionMode_t mode) {\n    const int convDim = 2;\n\n    int64_t strideA_padded[4];\n    int64_t outstrideA_padded[4];\n    int64_t filterstrideA_padded[4];\n\n    generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return common_conv_descriptors(cudnn_frontend::TensorBuilder()\n                                       .setDim(4, x_dim_padded)\n                                       .setStrides(4, strideA_padded)\n                                       .setId('x')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::TensorBuilder()\n                                       .setDim(4, y_dim_padded)\n                                       .setStrides(4, outstrideA_padded)\n                                       .setId('y')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::TensorBuilder()\n                                       .setDim(4, w_dim_padded)\n                                       .setStrides(4, filterstrideA_padded)\n                                       .setId('w')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::ConvDescBuilder()\n                                       .setDataType(CUDNN_DATA_FLOAT)\n                                       .setMathMode(mode)\n                                       .setNDims(convDim)\n                               
        .setStrides(convDim, convstrideA)\n                                       .setPrePadding(convDim, padA)\n                                       .setPostPadding(convDim, padA)\n                                       .setDilation(convDim, dilationA)\n                                       .build());\n}\n\nusing common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor>;\n\ncommon_convbias_descriptors\ncreate_conv_bias_add_act_descriptors(int64_t* x_dim_padded,\n                                     int64_t* padA,\n                                     int64_t* convstrideA,\n                                     int64_t* dilationA,\n                                     int64_t* w_dim_padded,\n                                     int64_t* y_dim_padded,\n                                     cudnnDataType_t dataType) {\n    const int convDim = 2;\n\n    int64_t b_dim_padded[4];\n    b_dim_padded[0] = 1;\n    b_dim_padded[1] = y_dim_padded[1];\n    b_dim_padded[2] = 1;\n    b_dim_padded[3] = 1;\n\n    int64_t x_stride_padded[4];\n    int64_t y_stride_padded[4];\n    int64_t w_stride_padded[4];\n    int64_t b_stride_padded[4];\n\n    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, y_stride_padded, 
4, CUDNN_TENSOR_NHWC);\n    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return common_convbias_descriptors(cudnn_frontend::TensorBuilder()\n                                           .setDim(4, x_dim_padded)\n                                           .setStrides(4, x_stride_padded)\n                                           .setId('x')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('y')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, w_dim_padded)\n                                           .setStrides(4, w_stride_padded)\n                                           .setId('w')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, b_dim_padded)\n                                           .setStrides(4, b_stride_padded)\n                                           .setId('z')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n   
                                        .setDim(4, b_dim_padded)\n                                           .setStrides(4, b_stride_padded)\n                                           .setId('b')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setVirtual()\n                                           .setId('A')  // after add\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setVirtual()\n                                           .setId('B')  // after bias\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('C')  // after conv\n                                           .setAlignment(16)\n                                           .setVirtual()\n                                           .setDataType(dataType)\n                                           .build(),\n                                 
      cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('i')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('D')  // after optional add\n                                           .setAlignment(16)\n                                           .setVirtual()\n                                           .setDataType(dataType)\n                                           .build());\n}\n\n// tensor descriptors used for dgrad\nenum {\n    X_OR_DX_TENSOR,\n    DY_TENSOR,\n    W_OR_DW_TENSOR,\n    SCALE_TENSOR,\n    RELU_TENSOR,\n    AFTER_DCONV_TENSOR,\n    AFTER_DRELU_TENSOR,\n};\n\nusing dconv_descriptors = std::tuple<cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor>;\n\ndconv_descriptors\ncreate_dconv_descriptors(int64_t* x_dim_padded,\n                         int64_t* padA,\n                         int64_t* convstrideA,\n                         int64_t* dilationA,\n                         int64_t* w_dim_padded,\n                         int64_t* y_dim_padded,\n                         cudnnDataType_t dataType) {\n    const int convDim = 2;\n\n    int64_t 
b_dim_padded[4];\n    b_dim_padded[0] = 1;\n    b_dim_padded[1] = x_dim_padded[1];\n    b_dim_padded[2] = 1;\n    b_dim_padded[3] = 1;\n\n    int64_t x_stride_padded[4];\n    int64_t y_stride_padded[4];\n    int64_t w_stride_padded[4];\n    int64_t b_stride_padded[4];\n\n    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return dconv_descriptors(cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setId('x')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, y_dim_padded)\n                             .setStrides(4, y_stride_padded)\n                             .setId('y')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, w_dim_padded)\n                             .setStrides(4, w_stride_padded)\n                             .setId('w')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, b_dim_padded)\n                             .setStrides(4, b_stride_padded)\n                             .setId('s')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                  
           .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setId('r')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setVirtual()\n                             .setId('A')  // after dconv\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setVirtual()\n                             .setId('B')  // after drelu\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build());\n}\n\n// create a cache for plan\nstd::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;\n\n// TODO: better name\nstd::string getConvFusionString(int64_t* x_dim_padded,\n                                int64_t* padA,\n                                int64_t* convstrideA,\n                                int64_t* dilationA,\n                                int64_t* w_dim_padded,\n                                cudnnDataType_t dataType,\n                                std::string fusion_string) {\n\n  for(int i=0;i<4;i++) {\n    fusion_string += 'X';\n    fusion_string += std::to_string(x_dim_padded[i]);\n  }\n  for(int i=0;i<4;i++) {\n    fusion_string += 'W';\n    fusion_string += std::to_string(w_dim_padded[i]);\n  }\n  for(int 
i=0;i<2;i++) {\n    fusion_string += 'P';\n    fusion_string += std::to_string(padA[i]);\n  }\n  for(int i=0;i<2;i++) {\n    fusion_string += 'S';\n    fusion_string += std::to_string(convstrideA[i]);\n  }\n  for(int i=0;i<2;i++) {\n    fusion_string += 'D';\n    fusion_string += std::to_string(dilationA[i]);\n  }\n  fusion_string += 'T';\n  fusion_string += std::to_string(dataType);\n  return fusion_string;\n}\n\ncudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,\n                                               std::stringstream& log_buf,\n                                               cudnn_frontend::OperationGraph& opGraph,\n                                               std::string cache_string,\n                                               bool use_heuristic = true){\n  auto it = plan_cache.find(cache_string);\n  if (it != plan_cache.end()) {\n    DEBUG_CUDNN_MSG(log_buf, \"Found plan in cache\");\n    return it->second;\n  } else {\n    if (use_heuristic){\n      // TODO: confirm which mode to use\n      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()\n        .setOperationGraph(opGraph)\n        .setHeurMode(CUDNN_HEUR_MODE_INSTANT)\n        .build();\n      // try 3 times for now as WAR for no heuristic training\n      int max_tries = 3, count = 0;\n      auto& engine_configs = heuristics.getEngineConfig(max_tries);\n      while(true) {\n        try {\n          plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()\n                                                     .setHandle(handle_)\n                                                     .setEngineConfig(engine_configs[count], opGraph.getTag())\n                                                     .build()));\n          break;\n        } catch (cudnn_frontend::cudnnException e) {\n          if (++count == max_tries) throw e;\n        }\n      }\n    }else{\n    DEBUG_CUDNN_MSG(log_buf, \"No plan in cache\");\n    // How many engines support this 
operation graph?\n    auto total_engines = opGraph.getEngineCount();\n    DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << \" has \" << total_engines << \" engines.\");\n    // We have to randomly pick one engine from [0, total_engines)\n    // Selecting \"0\" by default\n    auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();\n    DEBUG_CUDNN_MSG(log_buf, engine.describe());\n    auto& knobs = engine.getSupportedKnobs();\n    for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {\n      DEBUG_CUDNN_MSG(log_buf, it->describe());\n    }\n    if (knobs.begin() != knobs.end()) {\n      DEBUG_CUDNN_MSG(log_buf, \"Updated knob choice\");\n      knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);\n      DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());\n    }\n\n    // Create the requisite engine config\n    auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();\n    DEBUG_CUDNN_MSG(log_buf, engine_config.describe());\n    plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));\n    }\n\n    return plan_cache.find(cache_string)->second;\n  }\n}\n\nvoid\nrun_conv_scale_bias_add_activation(int64_t* x_dim_padded,\n                                   int64_t* pad,\n                                   int64_t* convstride,\n                                   int64_t* dilation,\n                                   int64_t* w_dim_padded,\n                                   int64_t* y_dim_padded,\n                                   cudnnDataType_t dataType,\n                                   at::Half* devPtrX,\n                                   at::Half* devPtrW,\n                                   at::Half* devPtrY,\n                                   at::Half* devPtrZ,\n                                   at::Half* devPtrB,\n                                   at::Half* devPtrI) {\n    
cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n        // Define the add operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n                           .setMode(CUDNN_POINTWISE_MUL)\n                           .setMathPrecision(CUDNN_DATA_FLOAT)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        // Define the bias operation\n        auto biasDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n        // optional add\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, 
addDesc.describe());\n\n        // Define the activation operation\n        auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                           .setMode(CUDNN_POINTWISE_RELU_FWD)\n                           .setMathPrecision(CUDNN_DATA_FLOAT)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                           .setxDesc(std::get<X_TENSOR>(tensors))\n                           .setwDesc(std::get<W_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                           .setcDesc(convDesc)\n                           .setAlpha(alpha)\n                           .setBeta(beta)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // Create a Scale Node (pointwise multiply).\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(conv_op.getOutputTensor())\n                           .setbDesc(std::get<Z_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERADD_TENSOR>(tensors))\n               
            .setpwDesc(scaleDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create a Bias Node.\n        auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(scale_op.getOutputTensor())\n                           .setbDesc(std::get<B_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))\n                           .setpwDesc(biasDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n        // Create an optional Add Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(bias_op.getOutputTensor())\n                           .setbDesc(std::get<OPTIONAL>(tensors))\n                           .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))\n                           .setpwDesc(addDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n\n        // Create an Activation Node.\n        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())\n                          .setyDesc(std::get<Y_TENSOR>(tensors))\n                          .setpwDesc(actDesc)\n                          .build();\n        DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n        // Create an Operation Graph. In this case it is convolution scale bias (optional add) activation\n        std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(devPtrI ? 
ops.size() : 4, ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};\n        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 'i'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n                               .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(devPtrI ? 6 : 5, data_ptrs)\n          .setUids(devPtrI ? 
6 : 5, uids)\n                               .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_conv_scale_bias(int64_t* x_dim_padded,\n                    int64_t* pad,\n                    int64_t* convstride,\n                    int64_t* dilation,\n                    int64_t* w_dim_padded,\n                    int64_t* y_dim_padded,\n                    cudnnDataType_t dataType,\n                    at::Half* devPtrX,\n                    at::Half* devPtrW,\n                    at::Half* devPtrY,\n                    at::Half* devPtrZ,\n                    at::Half* devPtrB) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, 
std::get<AFTERCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n        // Define the scale (multiply) operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_MUL)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        // Define the bias operation\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                           .setxDesc(std::get<X_TENSOR>(tensors))\n                           .setwDesc(std::get<W_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                           .setcDesc(convDesc)\n                           .setAlpha(alpha)\n                           .setBeta(beta)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // 
Create a Scale Node (pointwise multiply).\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(conv_op.getOutputTensor())\n          .setbDesc(std::get<Z_TENSOR>(tensors))\n          .setyDesc(std::get<AFTERADD_TENSOR>(tensors)) // TODO: change enum to aftermul\n          .setpwDesc(scaleDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create a Bias Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(scale_op.getOutputTensor())\n          .setbDesc(std::get<B_TENSOR>(tensors))\n          .setyDesc(std::get<Y_TENSOR>(tensors))\n          .setpwDesc(addDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n        // Create an Operation Graph. In this case it is convolution scale bias\n        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = 
workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};\n        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n                               .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(5, data_ptrs)\n          .setUids(5, uids)\n                               .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\n\nvoid\nrun_dconv_drelu_dscale(int64_t* x_dim_padded,\n                       int64_t* pad,\n                       int64_t* convstride,\n                       int64_t* dilation,\n                       int64_t* w_dim_padded,\n                       int64_t* y_dim_padded,\n                       cudnnDataType_t dataType,\n                       at::Half* devPtrX,\n                       at::Half* devPtrW,\n                       at::Half* devPtrY,\n                       at::Half* devPtrZ,\n                       at::Half* devPtrR) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, 
std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        // Define the activation backward operation\n        auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_RELU_BWD)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n        // Define the scale backward operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_MUL)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n          .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          
.setAlpha(alpha)\n          .setBeta(beta)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // TODO: do we need getOutputTensor(), and what it returns in backward case?\n        // Create an relu backward Node.\n        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setxDesc(std::get<RELU_TENSOR>(tensors))\n          .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n          .setpwDesc(actDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n        // Create a Scale Node.\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n          .setbDesc(std::get<SCALE_TENSOR>(tensors))\n          .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setpwDesc(scaleDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create an Operation Graph. 
In this case it is dconv drelu dscale\n        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};\n        int64_t uids[]    = {'x', 'y', 'w', 's', 'r'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(5, data_ptrs)\n          .setUids(5, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_dconv(int64_t* 
x_dim_padded,\n          int64_t* pad,\n          int64_t* convstride,\n          int64_t* dilation,\n          int64_t* w_dim_padded,\n          int64_t* y_dim_padded,\n          cudnnDataType_t dataType,\n          at::Half* devPtrX,\n          at::Half* devPtrW,\n          at::Half* devPtrY,\n          cudnnBackendDescriptorType_t mode) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        // mode should be one 
of the following\n        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR\n        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR\n        auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);\n        if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {\n          conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n            .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n            .setdyDesc(std::get<DY_TENSOR>(tensors))\n            .setcDesc(convDesc)\n            .setAlpha(alpha)\n            .setBeta(beta);\n        }\n        else {\n          conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n            .setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n            .setdyDesc(std::get<DY_TENSOR>(tensors))\n            .setcDesc(convDesc)\n            .setAlpha(alpha)\n            .setBeta(beta);\n        }\n        auto conv_op = conv_op_builder.build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // Create an Operation Graph. 
In this case it is a single convolution\n        std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};\n        int64_t uids[]    = {'x', 'y', 'w'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(3, data_ptrs)\n          .setUids(3, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_dconv_add(int64_t* x_dim_padded,\n              int64_t* pad,\n   
           int64_t* convstride,\n              int64_t* dilation,\n              int64_t* w_dim_padded,\n              int64_t* y_dim_padded,\n              cudnnDataType_t dataType,\n              at::Half* devPtrX,\n              at::Half* devPtrW,\n              at::Half* devPtrY,\n              at::Half* devPtrR) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        // Define the add backward operation\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_ADD)\n     
     .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n          .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          .setAlpha(alpha)\n          .setBeta(beta)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // TODO: do we need getOutputTensor(), and what it returns in backward case?\n        // Create add Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setbDesc(std::get<RELU_TENSOR>(tensors))\n          .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setpwDesc(addDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n        // Create an Operation Graph. 
In this case it is dconv add\n        std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};\n        int64_t uids[]    = {'x', 'y', 'w', 'r'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(4, data_ptrs)\n          .setUids(4, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\n\n// inputs contains 
x,w,z,b,(i)\nstd::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[]         = {0, 0, 0, 0};\n  int64_t filterdimA1[]  = {0, 0, 0, 0};\n  int64_t filterdimA2[]  = {0, 0, 0, 0};\n  int64_t filterdimA3[]  = {0, 0, 0, 0};\n  int64_t filterdimA4[]  = {0, 0, 0, 0};\n\n  // All dim calculations below follow the order n,c,h,w\n  int axis[] {0,1,2,3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim=0;dim<4;dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim=0;dim<4;dim++) {\n      filterdimA4[dim] = inputs[10].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA2[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA3[]     = {0, 0, 0, 0}; // Computed Below\n\n  // use these fixed values for the test run\n  int64_t padA[]        = {0, 0};\n  int64_t padA1[]        = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 
2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[]     = {0, 0, 0, 0};\n  int64_t outdim2[]     = {0, 0, 0, 0};\n  int64_t outdim3[]     = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim=0;dim<4;dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n  at::Half* b = inputs[7].data_ptr<at::Half>();\n  auto out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(dimA,\n                                     padA,\n                                     convstride1X1,\n                                     dilationA,\n                                     filterdimA1,\n                                     outdimA1,\n                                     CUDNN_DATA_HALF,\n                                     x,\n                                     w,\n                                     y1,\n                                     z,\n                                     b,\n                                     nullptr);\n\n  DEBUG_MSG(\"[DEBUG] new relu1 : \" << out1.to(at::kFloat).sum().item<float>());\n\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[5].data_ptr<at::Half>();\n  b = inputs[8].data_ptr<at::Half>();\n  auto out2 = at::empty(outdim2, inputs[0].type(), 
output_format);\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA1,\n                                     padA1,\n                                     convstrideA,\n                                     dilationA,\n                                     filterdimA2,\n                                     outdimA2,\n                                     CUDNN_DATA_HALF,\n                                     y1,\n                                     w,\n                                     y2,\n                                     z,\n                                     b,\n                                     nullptr);\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n\n  // create output of conv3\n  auto out3 = at::empty(outdim3, inputs[0].type(), output_format);\n  at::Half* y3 = out3.data_ptr<at::Half>();\n\n  // create output of conv4 that may exist\n  auto identity = at::empty_like(out3);\n  at::Half* yi = identity.data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){\n\n    w = inputs[10].data_ptr<at::Half>();\n    z = inputs[11].data_ptr<at::Half>();\n    b = inputs[12].data_ptr<at::Half>();\n    run_conv_scale_bias(dimA,\n                        padA,\n                        convstride1X1,\n                        dilationA,\n                        filterdimA4,\n                        outdimA3,\n                        CUDNN_DATA_HALF,\n                        x,\n                        w,\n                        yi,\n                        z,\n                        b);\n    DEBUG_MSG(\"[DEBUG] new downsample : \" << identity.to(at::kFloat).sum().item<float>());\n  }\n  else {\n    yi = x;\n  }\n\n  w = inputs[3].data_ptr<at::Half>();\n  z = inputs[6].data_ptr<at::Half>();\n  b = inputs[9].data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA2,\n                                     padA,\n                                     convstrideA,\n 
                                    dilationA,\n                                     filterdimA3,\n                                     outdimA3,\n                                     CUDNN_DATA_HALF,\n                                     y2,\n                                     w,\n                                     y3,\n                                     z,\n                                     b,\n                                     yi);\n  DEBUG_MSG(\"[DEBUG] new relu3 : \" << out3.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out1);\n  outputs.push_back(out2);\n  outputs.push_back(out3);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[]         = {0, 0, 0, 0};\n  int64_t filterdimA1[]  = {0, 0, 0, 0};\n  int64_t filterdimA2[]  = {0, 0, 0, 0};\n  int64_t filterdimA3[]  = {0, 0, 0, 0};\n  int64_t filterdimA4[]  = {0, 0, 0, 0};\n\n  // All dim calculations below follow the n,c,h,w order\n  int axis[] {0,1,2,3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim=0;dim<4;dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim=0;dim<4;dim++) {\n      filterdimA4[dim] = inputs[14].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA2[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA3[]     = {0, 0, 0, 0}; // Computed Below\n\n  // use these fixed values for the test run\n  int64_t padA[]        = {0, 0};\n  int64_t padA1[]        = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[]     = {0, 0, 0, 0};\n  int64_t outdim2[]     = {0, 0, 0, 0};\n  int64_t outdim3[]     = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim=0;dim<4;dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // dconv3+drelu2+dscale2\n  at::Half* conv_in = inputs[13].data_ptr<at::Half>();\n  at::Half* dy3 = inputs[10].data_ptr<at::Half>();\n\n  DEBUG_MSG(\"[DEBUG] new dconv3 : \" << inputs[10].to(at::kFloat).sum().item<float>());\n\n  // wgrad\n  auto wgrad3 = at::empty_like(inputs[3]);\n  at::Half* dw3 = wgrad3.data_ptr<at::Half>();\n  run_dconv(outdimA2,\n            padA,\n            convstrideA,\n            dilationA,\n            filterdimA3,\n            outdimA3,\n            
CUDNN_DATA_HALF,\n            conv_in,\n            dw3,\n            dy3,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n  at::Half* w = inputs[3].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n\n  at::Half* relu2 = inputs[13].data_ptr<at::Half>();\n\n  run_dconv_drelu_dscale(outdimA2,\n                         padA,\n                         convstrideA,\n                         dilationA,\n                         filterdimA3,\n                         outdimA3,\n                         CUDNN_DATA_HALF,\n                         dy2,\n                         w,\n                         dy3,\n                         z,\n                         relu2);\n\n  DEBUG_MSG(\"[DEBUG] new dconv2 : \" << grad_out2.to(at::kFloat).sum().item<float>());\n\n  // dconv2+drelu1+dscale1\n  conv_in = inputs[12].data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2 = at::empty_like(inputs[2]);\n  at::Half* dw2 = wgrad2.data_ptr<at::Half>();\n  run_dconv(outdimA1,\n            padA1,\n            convstrideA,\n            dilationA,\n            filterdimA2,\n            outdimA2,\n            CUDNN_DATA_HALF,\n            conv_in,\n            dw2,\n            dy2,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n  // fused dgrad\n  run_dconv_drelu_dscale(outdimA1,\n                         padA1,\n                         convstrideA,\n                         dilationA,\n                         filterdimA2,\n                         outdimA2,\n                         CUDNN_DATA_HALF,\n  
                       dy1,\n                         w,\n                         dy2,\n                         z,\n                         relu1);\n\n/*\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (stride_1X1 != 1){\n    // dgrad\n    run_dconv(outdimA1,\n              padA1,\n              convstride1X1,\n              dilationA,\n              filterdimA2,\n              outdimA2,\n              CUDNN_DATA_HALF,\n              dy1,\n              w,\n              dy2,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n    // mul fused mask\n    grad_out1.mul_(inputs[15]);\n  }\n  else {\n    at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n    // fused dgrad\n    run_dconv_drelu_dscale(outdimA1,\n                           padA1,\n                           convstride1X1,\n                           dilationA,\n                           filterdimA2,\n                           outdimA2,\n                           CUDNN_DATA_HALF,\n                           dy1,\n                           w,\n                           dy2,\n                           z,\n                           relu1);\n  }\n*/\n  DEBUG_MSG(\"[DEBUG] new dconv1 : \" << grad_out1.to(at::kFloat).sum().item<float>());\n\n  // create grads of conv4 that may exist\n  auto grad_x_conv4 = at::empty_like(inputs[0]);\n  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();\n  at::Tensor wgrad4;\n\n  // x used for dconv1 and dconv4 wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){\n    w = inputs[14].data_ptr<at::Half>();\n    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();\n    if (requires_grad) {\n      run_dconv(dimA,\n                padA,\n                convstride1X1,\n                dilationA,\n                filterdimA4,\n                outdimA3,\n                CUDNN_DATA_HALF,\n                dx_conv4,\n           
     w,\n                dy_conv4,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx\n      // DEBUG_MSG(\"[DEBUG] new dx_identity : \" << grad_x_conv4.to(at::kFloat).sum().item<float>());\n    }\n    // wgrad\n    wgrad4 = at::empty_like(inputs[14]);\n    at::Half* dw4 = wgrad4.data_ptr<at::Half>();\n    run_dconv(dimA,\n              padA,\n              convstride1X1,\n              dilationA,\n              filterdimA4,\n              outdimA3,\n              CUDNN_DATA_HALF,\n              x,\n              dw4,\n              dy_conv4,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  }\n  else {\n    // if there is no downsample, dx_conv4 is fork of drelu3\n    dx_conv4 = inputs[11].data_ptr<at::Half>();\n  }\n\n  // dconv1+add\n  // wgrad\n  auto wgrad1 = at::empty_like(inputs[1]);\n  at::Half* dw1 = wgrad1.data_ptr<at::Half>();\n  run_dconv(dimA,\n            padA,\n            convstride1X1,\n            dilationA,\n            filterdimA1,\n            outdimA1,\n            CUDNN_DATA_HALF,\n            x,\n            dw1,\n            dy1,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  w = inputs[1].data_ptr<at::Half>();\n  auto grad_x = at::empty_like(inputs[0]);\n  at::Half* dx = grad_x.data_ptr<at::Half>();\n\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (requires_grad){\n    if (stride_1X1 != 1){\n      run_dconv(dimA,\n                padA,\n                convstride1X1,\n                dilationA,\n                filterdimA1,\n                outdimA1,\n                CUDNN_DATA_HALF,\n                dx,\n                w,\n                dy1,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // add 2 together\n      
grad_x.add_(grad_x_conv4);\n    }\n    else {\n      run_dconv_add(dimA,\n                    padA,\n                    convstride1X1,\n                    dilationA,\n                    filterdimA1,\n                    outdimA1,\n                    CUDNN_DATA_HALF,\n                    dx,\n                    w,\n                    dy1,\n                    dx_conv4);\n    }\n  }\n\n  DEBUG_MSG(\"[DEBUG] new dx : \" << grad_x.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad1 : \" << wgrad1.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad2 : \" << wgrad2.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad3 : \" << wgrad3.to(at::kFloat).sum().item<float>());\n  outputs.push_back(grad_x);\n  outputs.push_back(wgrad1);\n  outputs.push_back(wgrad2);\n  outputs.push_back(wgrad3);\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    DEBUG_MSG(\"[DEBUG] new wgrad4 : \" << wgrad4.to(at::kFloat).sum().item<float>());\n    outputs.push_back(wgrad4);\n  }\n\n  return outputs;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &bottleneck_forward, \"Bottleneck block forward\");\n  m.def(\"backward\", &bottleneck_backward, \"Bottleneck block backward\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/fmha_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include \"fmha.h\"\n\nvoid set_params(Fused_multihead_attention_fprop_params &params,\n                // sizes\n                const size_t b,\n                const size_t s,\n                const size_t h,\n                const size_t d,\n                // device pointers\n                void *qkv_packed_d,\n                void *cu_seqlens_d,\n                void *o_packed_d,\n                void *s_d,\n                float p_dropout) {\n\n    Data_type acc_type = DATA_TYPE_FP32;\n    Data_type data_type = DATA_TYPE_FP16;\n\n    // Reset the parameters\n    memset(&params, 0, sizeof(params));\n\n    // Set the pointers and strides.\n    params.qkv_ptr = qkv_packed_d;\n    params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);\n    params.o_ptr = o_packed_d;\n    params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);\n\n    params.cu_seqlens = static_cast<int *>(cu_seqlens_d);\n\n    // S = softmax(P)\n    params.s_ptr = s_d;\n    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);\n\n    // Set the dimensions.\n    params.b = b;\n    params.h = h;\n    params.s = s;\n    params.d = d;\n\n    // Set the different scale values.\n    const float scale_bmm1 = 1.f / sqrtf(d);\n    constexpr float scale_softmax = 1.f;\n    constexpr float scale_bmm2 = 
1.f;\n\n    set_alpha(params.scale_bmm1, scale_bmm1, acc_type);\n    set_alpha(params.scale_softmax, scale_softmax, acc_type);\n    set_alpha(params.scale_bmm2, scale_bmm2, data_type);\n\n    // Set this to probability of keeping an element to simplify things.\n    params.p_dropout = 1.f - p_dropout;\n    params.rp_dropout = 1.f / params.p_dropout;\n    TORCH_CHECK(p_dropout < 1.f);\n    set_alpha(params.scale_dropout, params.rp_dropout, data_type);\n}\n\nstd::vector<at::Tensor>\nmha_fwd(const at::Tensor &qkv,  // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n        const at::Tensor &cu_seqlens,  // b+1\n        const float p_dropout,\n        const int max_seq_len,\n        const bool is_training,\n        c10::optional<at::Generator> gen_) {\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);\n    int seq_len = 512;\n    auto launch = &run_fmha_fp16_512_64_sm80;\n    if( max_seq_len <= 128 ) {\n        seq_len = 128;\n        launch = &run_fmha_fp16_128_64_sm80;\n    } else if( max_seq_len <= 256 ) {\n        seq_len = 256;\n        launch = &run_fmha_fp16_256_64_sm80;\n    } else if( max_seq_len <= 384 ) {\n        seq_len = 384;\n        launch = &run_fmha_fp16_384_64_sm80;\n    } else if( max_seq_len <= 512 ) {\n        seq_len = 512;\n        launch = &run_fmha_fp16_512_64_sm80;\n    } else {\n        TORCH_CHECK(false);\n    }\n\n    constexpr int warps_m = 1;\n    constexpr int warps_n = 4;  // this leads to an upper bound\n    const int mmas_m = seq_len / 16 / warps_m;\n    const int mmas_n = seq_len / 16 / warps_n;\n    \n    const int elts_per_thread = 8 * mmas_m * mmas_n;\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.dtype() == torch::kFloat16);\n    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    TORCH_CHECK(qkv.is_contiguous())\n    
TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n    auto opts = qkv.options();\n\n    auto ctx = torch::empty({ total, num_heads, head_size }, opts);\n\n    auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               ctx.data_ptr(),\n               s.data_ptr(),\n               p_dropout);\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    int64_t counter_offset = elts_per_thread;\n    at::PhiloxCudaState rng_engine_inputs;\n\n    if( is_training ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n\n    launch(params, is_training, stream);\n\n    return { ctx, s };\n}\n\nstd::vector<at::Tensor>\nmha_bwd(const at::Tensor &dout,  // total x num_heads, x head_size\n        const at::Tensor &qkv,   // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n        at::Tensor &softmax,     // b x h x s x s softmax and dmask - will be overwritten with dP\n        const at::Tensor &cu_seqlens,  // b+1\n        const float p_dropout,         // 
probability to drop\n        const int max_seq_len          // max sequence length to choose the kernel\n) {\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);\n    int seq_len = 512;\n    auto launch = &run_fmha_dgrad_fp16_512_64_sm80;\n    if( max_seq_len <= 128 ) {\n        seq_len = 128;\n        launch = &run_fmha_dgrad_fp16_128_64_sm80;\n    } else if( max_seq_len <= 256 ) {\n        seq_len = 256;\n        launch = &run_fmha_dgrad_fp16_256_64_sm80;\n    } else if( max_seq_len <= 384 ) {\n        seq_len = 384;\n        launch = &run_fmha_dgrad_fp16_384_64_sm80;\n    } else if( max_seq_len <= 512 ) {\n        seq_len = 512;\n        launch = &run_fmha_dgrad_fp16_512_64_sm80;\n    } else {\n        TORCH_CHECK(false);\n    }\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.dtype() == torch::kFloat16);\n    TORCH_CHECK(dout.dtype() == torch::kFloat16);\n    TORCH_CHECK(softmax.dtype() == torch::kFloat16);\n    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);\n\n    TORCH_CHECK(qkv.is_cuda());\n    TORCH_CHECK(cu_seqlens.is_cuda());\n\n    TORCH_CHECK(qkv.is_contiguous());\n    TORCH_CHECK(cu_seqlens.is_contiguous());\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n\n    auto dqkv = torch::empty_like(qkv);\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               dout.data_ptr(),     // we set o_ptr to dout\n               
softmax.data_ptr(),  // softmax gets overwritten by dP!\n               p_dropout);\n\n    // we're re-using these scales\n    Data_type acc_type = DATA_TYPE_FP32;\n    set_alpha(params.scale_bmm1, 1.f, acc_type);\n    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n    params.dqkv_ptr = dqkv.data_ptr();\n\n    launch(params, stream);\n    return { dqkv, softmax };\n}\n\nstd::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n                                const at::Tensor &cu_seqlens,  // b+1\n                                const float p_dropout,\n                                const int max_seq_len,\n                                const bool is_training,\n                                c10::optional<at::Generator> gen_) {\n    int seq_len = 512;\n    auto launch = &run_fmha_fp16_512_64_sm80_nl;\n    TORCH_CHECK(max_seq_len == seq_len);\n\n    constexpr int warps_m = 1;\n    constexpr int warps_n = 4;  // this leads to an upper bound\n    const int mmas_m = seq_len / 16 / warps_m;\n    const int mmas_n = seq_len / 16 / warps_n;\n    // static_assert( mmas_m == 32 );\n    // static_assert( mmas_n == 4 );\n    const int elts_per_thread = 8 * mmas_m * mmas_n;\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    TORCH_CHECK(qkv.is_contiguous())\n    TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n    auto opts = qkv.options();\n\n   
 auto ctx = torch::empty({ total, num_heads, head_size }, opts);\n\n    auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               ctx.data_ptr(),\n               s.data_ptr(),\n               p_dropout);\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    int64_t counter_offset = elts_per_thread;\n    at::PhiloxCudaState rng_engine_inputs;\n\n    if( is_training ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n    int num_chunks = 3;\n    if(batch_size == 3) {\n        num_chunks = 2;\n    }\n\n    launch(params, is_training, num_chunks, stream);\n\n    return { ctx, s };\n}\n\nstd::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout,        // total x num_heads, x head_size\n                                const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n                                at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP\n                                const at::Tensor &cu_seqlens,  // b+1\n                                const float p_dropout,         // probability to drop\n                                const int max_seq_len          // max sequence length to choose the kernel\n) {\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    
TORCH_CHECK(qkv.is_contiguous())\n    TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    \n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n\n    int seq_len = 512;\n    auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;\n\n    auto opts = qkv.options();\n\n    auto dqkv = torch::empty_like(qkv);\n\n    int num_chunks = 2;\n    if( batch_size == 1 ) {\n        num_chunks = 4;\n    }else if( batch_size == 2 ) {\n        num_chunks = 3;\n    }\n    auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               dout.data_ptr(),     // o_ptr = dout\n               softmax.data_ptr(),  // softmax gets overwritten by dP!\n               p_dropout);\n\n    params.dkv_ptr = dkv.data_ptr();\n\n    Data_type acc_type = DATA_TYPE_FP32;\n    set_alpha(params.scale_bmm1, 1.f, acc_type);\n    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n    params.dqkv_ptr = dqkv.data_ptr();\n\n    launch(params, num_chunks, stream);\n\n    //SPLIT-K reduction of num_chunks dK, dV parts\n\n    // The equivalent of the following Pytorch code:\n    // using namespace torch::indexing;\n    // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});\n    // torch::sum_out(view_out, dkv, 1);\n\n    const int hidden_size = num_heads * head_size;\n    fmha_run_noloop_reduce(\n        
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);\n\n    return { dqkv, softmax, dkv };\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"Fused Multi-head Self-attention for BERT\";  \n    m.def(\"fwd\", &mha_fwd, \"Forward pass\");\n    m.def(\"bwd\", &mha_bwd, \"Backward pass\");\n    m.def(\"fwd_nl\", &mha_fwd_nl, \"Forward pass (small-batch)\");\n    m.def(\"bwd_nl\", &mha_bwd_nl, \"Backward pass (small-batch)\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/gemm.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/utils.h>\n\n#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >\nstruct Fragment_base_ {\n\n    // The data type.\n    using Data_type = Data_type_;\n    // default input type\n    using Input_type_ = Data_type_;\n    // Does it store the array of elements.\n    enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };\n    // The number of elements.\n    enum { NUM_ELTS = NUM_ELTS_ };\n    // The size of element in bits.\n    enum { BITS_PER_ELT = BITS_PER_ELT_ };\n    // The size of byte of a single register.\n    enum { BYTES_PER_REG = 4 };\n    // The size in bits.\n    enum { BITS_PER_REG = BYTES_PER_REG * 8 };\n    // The number of registers needed to store the fragment.\n    enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };\n    // The size in bytes (as returned by sizeof(Fragment_base<>).\n    enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };\n    // The alignment.\n    enum { ALIGNMENT = ALIGNMENT_ > 0 ? 
ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The type of the elements.\n    typename Data_type_,\n    // The number of elements.\n    int NUM_ELTS_,\n    // The alignment if you want to force a value -- use 0 otherwise.\n    int ALIGNMENT_ = 0,\n    // The base class.\n    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>\n>\nstruct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {\n\n    // The size of a load/store.\n    enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };\n\n    // Clear the fragment. Using PTX in that code seems to produce better SASS...\n    inline __device__ void clear() {\n        #pragma unroll\n        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {\n            asm volatile(\"mov.u32 %0, 0; \\n\" : \"=r\"(this->reg(ii)) : );\n        }\n    }\n\n    // Immutable access to a register.\n    inline __device__ const uint32_t& reg(int ii) const {\n        return this->regs_[ii];\n    }\n\n    // Mutable access to a register.\n    inline __device__ uint32_t& reg(int ii) {\n        return this->regs_[ii];\n    }\n\n    uint32_t regs_[Base_::NUM_REGS];\n\n    // Immutable access to the elements.\n    inline __device__ const Data_type_& elt(int ii) const {\n        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];\n    }\n\n    // Mutable access to the elements.\n    inline __device__ Data_type_& elt(int ii) {\n        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];\n    }\n\n    // Immutable access to the elements with a cast.\n    template< typename Cast_type >\n    inline __device__ const Cast_type& elt_as(int ii) const {\n        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];\n    }\n\n    // Mutable access to the elements.\n    template< typename Cast_type >\n    inline __device__ 
Cast_type& elt_as(int ii) {\n        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];\n    }\n\n    // Add another fragment.\n    inline __device__ void add(const Fragment &other) {\n        #pragma unroll\n        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {\n            this->elt(ii) += other.elt(ii);\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Layout >\nstruct Fragment_a : public Fragment<uint16_t, 8> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Layout >\nstruct Fragment_b : public Fragment<uint16_t, 8> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fragment_accumulator : public Fragment<float, 8> {\n\n    // The base class.\n    using Base = Fragment<float, 8>;\n\n    // Add two fragments.\n    template< typename Other_fragment_ >\n    inline __device__ void add(const Other_fragment_ &other) {\n        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {\n            this->elt(ii) = this->elt(ii) + other.elt(ii);\n        }\n    }\n\n    // Do the HMMA.\n    template< typename Layout_a, typename Layout_b >\n    inline __device__ void mma(const Fragment_a<Layout_a> &a,\n                               const Fragment_b<Layout_b> &b) {\n        asm volatile( \\\n            \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\" \\\n            \"    {%0, %1, %2, %3}, \\n\" \\\n            \"    {%4, %5, %6, %7}, \\n\" \\\n            \"    {%8, %9}, \\n\" \\\n            \"    {%0, %1, %2, %3}; \\n\" \\\n                    : \"+f\"(  elt(0)), \"+f\"(  elt(1)), \"+f\"(  elt(2)), \"+f\"(  elt(3))\n                    :  \"r\"(a.reg(0)),  \"r\"(a.reg(1)),  \"r\"(a.reg(2)),  \"r\"(a.reg(3))\n                    ,  \"r\"(b.reg(0)),  \"r\"(b.reg(1)));\n        asm volatile( \\\n            
\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\" \\\n            \"    {%0, %1, %2, %3}, \\n\" \\\n            \"    {%4, %5, %6, %7}, \\n\" \\\n            \"    {%8, %9}, \\n\" \\\n            \"    {%0, %1, %2, %3}; \\n\" \\\n                    : \"+f\"(  elt(4)), \"+f\"(  elt(5)), \"+f\"(  elt(6)), \"+f\"(  elt(7))\n                    :  \"r\"(a.reg(0)),  \"r\"(a.reg(1)),  \"r\"(a.reg(2)),  \"r\"(a.reg(3))\n                    ,  \"r\"(b.reg(2)),  \"r\"(b.reg(3)));\n    }\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Fragment, int M, int N >\ninline __device__ void clear(Fragment (&frag)[M][N]) {\n    #pragma unroll\n    for( int mi = 0; mi < M; ++mi ) {\n        #pragma unroll\n        for( int ni = 0; ni < N; ++ni ) {\n            frag[mi][ni].clear();\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Accumulator_type, int WARPS_K >\nstruct Clear_accumulator {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int WARPS_K >\nstruct Clear_accumulator<float, WARPS_K> {\n  template< typename Acc, int M, int N >\n  static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {\n    fmha::clear(acc);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Acc, typename A, typename B, int M, int N>\ninline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {\n\n    #pragma unroll\n    for( int mi = 0; mi < M; ++mi ) {\n        #pragma unroll\n        for( int ni = 0; ni < N; ++ni ) {\n            acc[mi][ni].mma(a[mi], b[ni]);\n        }\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The number of rows in the CTA tile.\n    int M_,\n    // The number of cols in the CTA tile.\n    int N_,\n    // The number of elements in the K dimension of the GEMM loop.\n    int K_,\n    // The number of rows of warps.\n    int WARPS_M_,\n    // The number of cols of warps.\n    int WARPS_N_,\n    // The number of warps in the K dimension of the GEMM loop.\n    int WARPS_K_>\nstruct Cta_tile_ {\n\n    enum { M = M_, N = N_, K = K_ };\n    // The number of warps.\n    enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };\n    // The number of warps per CTA.\n    enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };\n    // The number of threads per warp.\n    enum { THREADS_PER_WARP = 32 };\n    // The number of threads per CTA.\n    enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Hmma_tile {\n    // The number of elements computed with a single warp-MMA.\n    enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };\n\n    // The number of elements computed with a single CTA-MMA.\n    enum {\n        M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,\n        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,\n        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K\n    };\n\n    // The number of MMAs needed to compute the GEMM.\n    enum {\n        MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,\n        MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,\n        MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,\n    };\n\n    // The number of elements computed per warp.\n    enum {\n        M_PER_WARP = MMAS_M * M_PER_MMA,\n        N_PER_WARP = MMAS_N * N_PER_MMA,\n        K_PER_WARP = MMAS_K * K_PER_MMA,\n    
};\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing A_type = uint16_t;\nusing B_type = uint16_t;\nusing C_type = uint16_t;\nusing Accumulator_type = float;\nusing Epilogue_type = float;\n\nconstexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;\nconstexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;\nconstexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>\nusing Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile_>\nusing Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,\n                                                   Cta_tile_::N,\n                                                   Next_power_of_two<Cta_tile_::K>::VALUE,\n                                                   Cta_tile_::WARPS_M,\n                                                   Cta_tile_::WARPS_N,\n                                                   Cta_tile_::WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The number of bits per element.\n    int BITS_PER_ELEMENT,\n    // The number of rows of Q, K or V loaded by this tile.\n    int ROWS,\n    // The number of columns.\n    int COLS,\n    // The number of matrices.\n    int NUM_MATS = 3\n>\nstruct Gmem_tile_qkv {\n\n    // The size of each LDG.\n    enum { BYTES_PER_LDG = 16 };\n    // The size of a row in bytes.\n    enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };\n\n    // The number of threads to load a \"row\" of the matrix.\n    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };\n\n    // The number of \"rows\" loaded per LDG.\n    enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // The number of LDGs needed to load a chunk of the Q matrix.\n    enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };\n\n    // Ctor.\n    template< typename Params, typename BInfo >\n    inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)\n        : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)\n        , actual_seqlen(binfo.actual_seqlen)\n        , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) 
{\n\n        // Compute the position in the sequence (within the CTA for the moment).\n        int row = tidx / THREADS_PER_ROW;\n        // Compute the position of the thread in the row.\n        int col = tidx % THREADS_PER_ROW;\n\n        // Store the row as we need it to disable the loads.\n        row_ = row;\n\n        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.\n        int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;\n        // Add the block index.\n        row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        qkv_ptr_ += row_offset + col * BYTES_PER_LDG;\n    }\n\n    // Store data to shared memory.\n    template< typename Smem_tile >\n    inline __device__ void commit(Smem_tile &smem_tile) {\n        smem_tile.store(fetch_);\n    }\n\n    // Load data from memory.\n    template< typename Smem_tile >\n    inline __device__ void load(Smem_tile &smem_tile) {\n        const void *ptrs[LDGS];\n        uint32_t preds[LDGS];\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));\n            fetch_[ii] = make_uint4(0, 0, 0, 0);\n        }\n\n        // not packing predicates removes restrictions (e.g. 
FP16 384, 4 warps)\n        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            fct.load(ii, preds[ii]);\n        }\n    }\n\n    // Store data to memory.\n    inline __device__ void store(const uint4 (&data)[LDGS]) {\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {\n                fmha::stg(ptr, data[ii]);\n            }\n        }\n    }\n\n    // Move the pointer to the next location.\n    inline __device__ void move() {\n        qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;\n        actual_seqlen -= ROWS;\n    }\n\n    // The stride between rows for the QKV matrices.\n    int64_t params_qkv_stride_in_bytes_;\n    // The pointer.\n    char *qkv_ptr_;\n    // The fetch registers.\n    uint4 fetch_[LDGS];\n    // Keep track of the row the thread is processing as we move the tile.\n    int row_;\n    // The length of the sequence loaded by that memory tile.\n    int actual_seqlen;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile >\nstruct Gmem_tile_o {\n\n    // The mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // The size of each element.\n    enum { BYTES_PER_ELEMENT = 2 };\n    // The size of a row in bytes.\n    enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };\n\n    // The number of threads to store a \"row\" of the matrix.\n    enum { THREADS_PER_ROW = 16 };\n    // The size of each STG.\n    enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };\n\n    // The number of \"rows\" stored per iteration of the loop. The output of 1 MMA.\n    enum { ROWS = Cta_tile::M };\n    // The number of \"rows\" stored per iteration of the loop. 
The output of 1 MMA.\n    enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of outer loops for the stores.\n    enum { LOOPS = ROWS / ROWS_PER_LOOP };\n\n    // The number of \"rows\" stored per STG.\n    enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Do we have to guard against partial writes/reads.\n    enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };\n    // The number of STGs needed to store a chunk of the Q matrix.\n    enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };\n    // The number of STGs needed to store a chunk of the Q matrix in total.\n    enum { STGS = STGS_PER_LOOP * LOOPS };\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, int tidx)\n        : params_o_stride_in_bytes_(params.o_stride_in_bytes)\n        , actual_seqlen_(binfo.actual_seqlen)\n        , o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {\n\n        // Compute the position in the sequence (within the CTA for the moment).\n        int row = tidx / THREADS_PER_ROW;\n        // Compute the position of the thread in the row.\n        int col = tidx % THREADS_PER_ROW;\n\n        // Store the row as we need it to disable loads.\n        row_ = row;\n\n        // The row offset in the batched GEMM.\n        int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;\n        // Assemble the final pointer.\n        o_ptr_ += row_offset + col * BYTES_PER_STG;\n\n        // Is that thread active on the last STG?\n        if( HAS_INCOMPLETE_STG ) {\n            is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;\n        }\n    }\n\n    // Store data to global memory.\n    inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {\n\n        #pragma unroll\n        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {\n            int 
jj = mi * STGS_PER_LOOP + ii;\n            if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {\n                break;\n            }\n\n            float x = reinterpret_cast<const float &>(src[ii].x);\n            float y = reinterpret_cast<const float &>(src[ii].y);\n            float z = reinterpret_cast<const float &>(src[ii].z);\n            float w = reinterpret_cast<const float &>(src[ii].w);\n            uint2 out = float4_to_half4(x, y, z, w);\n            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {\n                fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);\n            }\n        }\n    }\n\n    // Move the pointer to the next location.\n    inline __device__ void move() {\n        row_ += ROWS;\n        o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;\n    }\n\n    // The stride between rows for the O matrix.\n    int64_t params_o_stride_in_bytes_;\n    // The pointer.\n    char *o_ptr_;\n    // Is the thread active for the last STG?\n    int is_active_for_last_stg_;\n    // Keep track of the row to disable loads.\n    int row_;\n    // The length of the sequence loaded by that memory tile.\n    int actual_seqlen_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, int BYTES_PER_ELEMENT >\nstruct Gmem_tile_mma_sd {\n\n    // The mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // Each STG stores 8 elements.\n    enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };\n    // The number of MMAs in the M dimension.\n    enum { MMAS_M = Mma_tile::MMAS_M };\n    // The number of MMAs in the N dimension.\n    enum { MMAS_N = Mma_tile::MMAS_N };\n    // The number of rows computed per MMA per thread block.\n    enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of cols computed per MMA per thread block.\n    enum { N_PER_MMA_PER_CTA = 
Mma_tile::N_PER_MMA_PER_CTA };\n    // The number of threads per block.\n    enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };\n    // The size of each row in bytes. I.e. how many bytes are stored per STG.\n    enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };\n    // The fixed sequence length.\n    enum { SEQLEN = Cta_tile::N };\n    // The distance between two blocks (in bytes).\n    enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };\n    // The distance between elements stored per loop (in bytes).\n    enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };\n\n    // The type of elements stored per STG.\n    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx) \n        : ptr_(static_cast<char *>(ptr)) {\n\n        // The block index for the batch.\n        const int bidb = blockIdx.y;\n        // The block index for the head.\n        const int bidh = blockIdx.x;\n        // The block index.\n        size_t bidx = bidb * params.h + bidh;\n\n        // Set store location for each thread at the beginning of the loop\n        ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;\n    }\n\n    // Store to global memory.\n    inline __device__ void store(const Type &data, const int mi, const int ni) {\n        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n        fmha::stg(ptr_ + offset, data);\n    }\n\n    // Load from global memory.\n    inline __device__ void load(Type &data, const int mi, const int ni) {\n        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n        fmha::ldg(data, ptr_ + offset);\n    }\n\n    // Move to the next tile.\n    inline __device__ void move() {\n        ptr_ += LOOP_STRIDE_BYTES;\n    }\n\n    // The pointer in global memory.\n    char 
*ptr_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >\nstruct Gmem_tile_mma_s : public Base {\n\n    // The number of mmas in the vertical dimension.\n    enum { M = Base::MMAS_M };\n    // The number of mmas in the horizontal dimension.\n    enum { N = Base::MMAS_N };\n    // The type of the vectors stored by each STG.\n    using Type = typename Base::Type;\n\n    // Ctor.\n    template< typename Params >\n    inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx) \n        : Base(ptr, params, tidx) {\n    }\n\n    // Store to global memory.\n    template<typename Mask>\n    inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n\n                float tmp00 = softmax[2 * mi + 0][4 * ni + 0];\n                float tmp01 = softmax[2 * mi + 0][4 * ni + 1];\n                float tmp02 = softmax[2 * mi + 0][4 * ni + 2];\n                float tmp03 = softmax[2 * mi + 0][4 * ni + 3];\n\n                float tmp10 = softmax[2 * mi + 1][4 * ni + 0];\n                float tmp11 = softmax[2 * mi + 1][4 * ni + 1];\n                float tmp12 = softmax[2 * mi + 1][4 * ni + 2];\n                float tmp13 = softmax[2 * mi + 1][4 * ni + 3];\n\n                uint4 dst;\n                dst.x = fmha::float2_to_half2(tmp00, tmp01);\n                dst.y = fmha::float2_to_half2(tmp02, tmp03);\n                dst.z = fmha::float2_to_half2(tmp10, tmp11);\n                dst.w = fmha::float2_to_half2(tmp12, tmp13);\n                if( mask.is_valid(mi, ni, 0, 0) ) {\n                    Base::store(dst, mi, ni);\n                }\n            }\n        }\n    }\n\n    // Load from global memory.\n    template<typename 
Mask>\n    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                regs[mi][ni] = make_uint4(0, 0, 0, 0);\n                if( mask.is_valid(mi, ni, 0, 0) ) {\n                    Base::load(regs[mi][ni], mi, ni);\n                }\n            }\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The base class.\n    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>\n>\nstruct Gmem_tile_dout : public Base {\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)\n        : Base(params, 0, binfo, tidx) {\n\n        this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);\n        this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes;  // needed for move\n\n        // Compute the position of the thread in the row.\n        int col = tidx % Base::THREADS_PER_ROW;\n\n        // The row offset in the batched GEMM. 
For each seq element, we store O in that order.\n        int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >\nstruct Gmem_tile_dq : public Base {\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_dq(const Params &params, const BInfo &binfo, int tidx) \n        : Base(params, binfo, tidx) {\n        this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);\n        this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes;  // needed for move\n\n        // Compute the position of the thread in the row.\n        int col = tidx % Base::THREADS_PER_ROW;\n\n        // The row offset in the batched GEMM. For each seq element, we store O in that order.\n        int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +\n                             (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>\nstruct FMHA_kernel_traits {\n\n    // The CTA description for the 1st GEMM.\n    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;\n    // The CTA description for the 2nd GEMM.\n    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;\n\n    // Do we use one buffer for K and V.\n    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;\n\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;\n\n    // The global memory tile to 
store O.\n    using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;\n    // The shared memory tile for O.\n    using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;\n\n    // The global memory tile to load/store S.\n    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;\n\n    // The shared memory tile to transpose S.\n    using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;\n\n    using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;\n\n    // Make sure the number of threads matches.\n    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, \"\");\n\n    // The number of threads.\n    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };\n    // Make sure the number of threads matches both CTAs.\n    static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, \"\");\n\n    // The amount of shared memory needed to load Q and K.\n    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };\n    // The extra amount of shared memory needed to load V.\n    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };\n    // The amount of shared memory needed for Q, K and V.\n    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };\n    // The amount of shared memory needed to load Q and store O.\n    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };\n\n    // The amount of shared memory needed for Q, K, V and O.\n    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };\n    // Make sure we have enough shared memory.\n    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, \"\");\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/mask.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n\ntemplate<typename Cta_tile>\nstruct Mask {\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    template<typename Params, typename BInfo>\n    __device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {\n\n        actual_seqlen = blockInfo.actual_seqlen;\n\n        const int warp = tidx / Cta_tile::THREADS_PER_WARP;\n        const int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n        static_assert(Cta_tile::WARPS_K == 1, \"\");\n\n        // find the warp in the Cta tile\n        const int warp_n = (warp / Cta_tile::WARPS_M);\n        const int warp_m = (warp % Cta_tile::WARPS_M);\n        // decompose warp into 8x4 tile\n        const int quad = lane / 4;\n        const int tid = (lane % 4) * 2;\n        row = warp_m * 16 + quad;\n        col = warp_n * 16 + tid;\n    }\n\n    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {\n\n        // ii and jj iterate over the 2x4 fragment\n        const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;\n        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;\n        return col_valid;\n        // return row_valid && col_valid;\n    }\n\n    inline __device__ void load(int it) {\n        row_offset = it * Cta_tile::M + row;\n    }\n    int 
row_offset;\n\n    int row;\n    int col;\n    int actual_seqlen;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/smem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/utils.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The description of the tile computed by this CTA.\n    typename Cta_tile, \n    // The number of rows in the 2D shared memory buffer.\n    int M_, \n    // The number of cols.\n    int N_, \n    // The size in bits of each element.\n    int BITS_PER_ELEMENT_, \n    // The number of bytes per STS.\n    int BYTES_PER_STS_ = 16,\n    // The number of buffers. (Used in multistage and double buffer cases.)\n    int BUFFERS_PER_TILE_ = 1,\n    // Do we enable the fast path for LDS.128 and friends.\n    int ENABLE_LDS_FAST_PATH_ = 0, \n    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. \n    int ROWS_PER_XOR_PATTERN_ = 8,\n    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. 
\n    int COLS_PER_XOR_PATTERN_ = 1,\n    // Use or not predicates\n    bool USE_PREDICATES_ = true\n>\nstruct Smem_tile_without_skews {\n\n    // The size in bits of each element.\n    enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };\n    // The size in bytes of a single STS.\n    enum { BYTES_PER_STS = BYTES_PER_STS_ };\n    // The number of elements per STS.\n    enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };\n    // To support arbitrary N, we pad some values to a power-of-2.\n    enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE }; \n    // The number of bytes per row without packing of rows.\n    enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };\n    // The number of bytes per row -- we want at least 128B per row.\n    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };\n    // The number of rows in shared memory (two rows may be packed into a single one).\n    enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };\n\n    // The number of threads per row.\n    enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };\n    // The number of threads per row.\n    enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };\n\n    // The number of STS per row.\n    enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };\n    // It must be at least one.\n    static_assert(STS_PER_ROW >= 1, \"\");\n    // The number of rows written with a single STS.\n    enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Make sure we write to at least one row per STS. Thanks Dr. 
Obvious ;)\n    static_assert(ROWS_PER_STS >= 1, \"\");\n    // The number of STS needed to store all rows.\n    enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };\n    // The number of STS in total.\n    enum { STS = STS_PER_COL * STS_PER_ROW };\n\n    // The size of one buffer in bytes in shared memory.\n    enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };\n    // The number of buffers. \n    enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };\n    // The size in bytes of total buffers.\n    enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };\n    // The boundary for smem_read_offset and smem_write_offset increment.\n    enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };\n\n    // Do we enable the LDS.128 fast path?\n    enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };\n    static_assert(ENABLE_LDS_FAST_PATH == 0);\n    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. \n    enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };\n    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. \n    enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };\n    // Use or not predicates\n    enum { USE_PREDICATES = USE_PREDICATES_ };\n\n    // The type of elements that are stored in shared memory by each thread.\n    using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;\n\n    // Ctor.\n    inline __device__ Smem_tile_without_skews(void *smem, int tidx) \n        : smem_(__nvvm_get_smem_pointer(smem)) {\n\n        // The row written by a thread. 
See doc/mma_smem_layout.xlsx.\n        int smem_write_row = tidx / THREADS_PER_ROW;\n\n        // The XOR pattern.\n        int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;\n        // Compute the column and apply the XOR pattern.\n        int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;\n\n        // The offset.\n        this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;\n\n        // TODO: Why not merge it with the read offset?\n        this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n        this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n    }\n\n    // Compute the store pointers.\n    template< int N >\n    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {\n        #pragma unroll\n        for( int ii = 0; ii < N; ++ii ) {\n            // Decompose the STS into row/col.\n            int row = ii / STS_PER_ROW;\n            int col = ii % STS_PER_ROW;\n\n            // Assemble the offset.\n            int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;\n\n            // Take the column into account.\n            if( STS_PER_ROW > 1 ) {\n                offset += col*THREADS_PER_ROW*BYTES_PER_STS; \n            }\n\n            // Apply the XOR pattern if needed.\n            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {\n                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;\n                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;\n            }\n\n            // Assemble the final pointer :)\n            ptrs[ii] = smem_ + offset + smem_write_buffer_;\n        }\n    }\n\n    inline __device__ void debug_reset() {\n        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n        for( int row = 0; row < ROWS; ++row ) {\n            for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {\n                if( threadIdx.x == 0 ) {\n                    uint32_t val = 0x0;\n     
               sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);\n                }\n            }\n        }\n        }\n    }\n\n    // Print the content of the tile (only for debug ;)).\n    inline __device__ void debug_print() const {\n        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n        for( int row = 0; row < ROWS; ++row ) {\n            for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {\n                if( threadIdx.x == 0 ) {\n                    uint32_t val;\n                    lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);\n                    printf(\"block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\\n\",\n                        blockIdx.x,\n                        blockIdx.y,\n                        blockIdx.z,\n                        smem_,\n                        buffer,\n                        row,\n                        col,\n                        val);\n                }\n            }\n        }\n        }\n    }\n\n    // Move the read offset to next buffer.\n    inline __device__ void move_to_next_read_buffer() {\n        if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {\n            this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n        } else if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_read_buffer_ += BYTES_PER_BUFFER;\n        }\n    }\n\n    // Move the read offset to next buffer. TODO: Remove this member function!!!\n    inline __device__ void move_next_read_buffer() {\n        this->move_to_next_read_buffer();\n    }\n\n    // Move the read offset to next N buffer (circular-buffer).\n    inline __device__ void move_to_next_read_buffer(int N) {\n        if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_read_buffer_ += N * BYTES_PER_BUFFER;\n            this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? 
BYTES_PER_TILE : 0;\n        }\n    }\n\n    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!\n    inline __device__ void move_next_read_buffer(int N) {\n        this->move_to_next_read_buffer(N);\n    }\n\n    // Move the write offset to next buffer.\n    inline __device__ void move_to_next_write_buffer() {\n        if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {\n            this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n        } else if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_write_buffer_ += BYTES_PER_BUFFER;\n        }\n    }\n\n    // Move the write offset to next buffer. TODO: Remove that member function!\n    inline __device__ void move_next_write_buffer() {\n        this->move_to_next_write_buffer();\n    }\n\n    // Move the read offset.\n    inline __device__ void move_read_offset(int delta) {\n        this->smem_read_offset_ += delta;\n    }\n\n    // Move the write offset.\n    inline __device__ void move_write_offset(int delta) {\n        this->smem_write_offset_ += delta;\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {\n        uint32_t smem_ptrs[N];\n        this->compute_store_pointers(smem_ptrs);\n        sts(smem_ptrs, data);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N, int M >\n    inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {\n        uint32_t smem_ptrs[N];\n        this->compute_store_pointers(smem_ptrs);\n        sts(smem_ptrs, data, preds);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { \n        this->store(data, preds);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void 
store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {\n        uint32_t tmp[1] = { preds };\n        this->store(gmem_ptrs, tmp);\n    }\n\n    // The shared memory pointer.\n    uint32_t smem_;\n    // The read offset. Reserve 4 offsets if needed.\n    int smem_read_offset_;\n    // The write offset.\n    int smem_write_offset_;\n    // The buffer base offset for read.\n    int smem_read_buffer_;\n    // The buffer base offset for write.\n    int smem_write_buffer_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile, \n    // The layout of the tile.\n    typename Layout, \n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true\n>\nstruct Smem_tile_a {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K, int MMAS_K_WITH_PADDING >\nstruct Compute_reset_mask {\n    // The potential mask.\n    enum { HALF = MMAS_K_WITH_PADDING / 2 };\n    // The remainder.\n    enum { MOD = MMAS_K % HALF };\n    // The final value.\n    enum { VALUE = (MMAS_K == MOD ? 
0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K_WITH_PADDING >\nstruct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {\n    enum { VALUE = 0 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K >\nstruct Compute_reset_mask<MMAS_K, MMAS_K> {\n    enum { VALUE = MMAS_K - 1 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_a {\n    // The size in bits.\n    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };\n    // The number of rows.\n    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE\n>\nstruct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,\n                                                               Cta_tile::M,\n                                                               Cta_tile::K,\n                                                               fmha::BITS_PER_ELEMENT_A,\n                                                               BYTES_PER_STS,\n                                                               BUFFERS_PER_TILE,\n             
                                                  0,\n                                                               ROWS_PER_XOR_PATTERN_,\n                                                               1> {\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::M,\n                                         Cta_tile::K,\n                                         fmha::BITS_PER_ELEMENT_A,\n                                         BYTES_PER_STS,\n                                         BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         1>;\n    // The fragment.\n    using Fragment = Fragment_a<Row>;\n\n    // When we use padding to reach a power of two, special care has to be taken.\n    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;\n    // The number of MMAs.\n    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // Ctor.\n    inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {\n\n        // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n\n        static_assert(WARPS_M == 1);\n        static_assert(WARPS_N == 4 || WARPS_N == 8);\n        static_assert(WARPS_K == 1);\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n        // The row and column read by the thread.\n        int smem_read_row  = (tidx & 0x0f);\n        int smem_read_col  = (tidx & 0x07);\n        smem_read_col ^= (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // Undo the pointer increment for the next ni.\n        // Should match the load function below for ki = 0.\n        if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {\n        #pragma unroll\n        for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {\n            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n            int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n            // Load using LDSM.M88.4.\n            uint4 tmp;\n            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n            // Store the value into the fragment.\n            a[mi].reg(0) = tmp.x;\n            a[mi].reg(1) = tmp.y;\n            a[mi].reg(2) = tmp.z;\n            a[mi].reg(3) = tmp.w;\n        }\n\n        // Move the offset to the next position. 
See doc/mma_smem_layout.xlsx.\n        static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {\n            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {\n            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {\n            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {\n            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Reset the read offset.\n    inline __device__ void reset_read_offset() {\n        // The number of MMAs in the K dimension.\n        enum { MMAS_K = Mma_tile::MMAS_K };\n        // The number of MMAs in the K dimension when we include padding.\n        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n        // Assemble the mask.\n        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n        // Reset the read offset.\n        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n    }\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_a<Cta_tile,\n                                    BYTES_PER_STS,\n                                    BUFFERS_PER_TILE> {\n    // The base class.\n    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, 
BUFFERS_PER_TILE>;\n\n    // Ctor.\n    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile, \n    // The layout of the tile.\n    typename Layout, \n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true\n>\nstruct Smem_tile_b {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_b {\n    // The size in bits.\n    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };\n    // The number of rows.\n    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE\n>\nstruct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,\n                                                           Cta_tile::N,\n                                                           Cta_tile::K,\n                                                           fmha::BITS_PER_ELEMENT_B,\n                                                           BYTES_PER_STS,\n       
                                                    BUFFERS_PER_TILE,\n                                                           0,\n                                                           ROWS_PER_XOR_PATTERN_,\n                                                           1> {\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::N,\n                                         Cta_tile::K,\n                                         fmha::BITS_PER_ELEMENT_B,\n                                         BYTES_PER_STS,\n                                         BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         1>;\n    // The fragment.\n    using Fragment = Fragment_b< Col>;\n\n    // When we use padding to reach a power of two, special care has to be taken.\n    using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;\n    // The number of MMAs.\n    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // The number of STS per thread\n    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n    // The number of STS per thread must be at least 1.\n    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n    // Ctor.\n    inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {\n\n        // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n        static_assert(WARPS_M == 1);\n        
static_assert(WARPS_N == 4 || WARPS_N == 8);\n        static_assert(WARPS_K == 1);\n\n        // The masks to select the warps.\n        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n\n        // The divisor for the warps.\n        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;\n\n        // The row and column read by the thread.\n        int smem_read_row  = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +\n                             (tidx & 0x07) +\n                             (tidx & 0x10) / 2;\n        int smem_read_col  = (tidx & 0x07);\n        smem_read_col ^= (tidx & 0x08) / 8;\n        // The shared memory offset.\n        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // Undo the pointer increment for the next ni.\n        // Should match the load function below for ki = 0.\n        if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n            int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n            // Load using LDSM.M88.4.\n            uint4 tmp;\n            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n            // Store the value into the fragment.\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n        }\n\n        // Move the offset to the next position. 
See doc/mma_smem_layout.xlsx.\n        static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {\n            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {\n            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {\n            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {\n            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Reset the read offset.\n    inline __device__ void reset_read_offset() {\n        // The number of MMAs in the K dimension.\n        enum { MMAS_K = Mma_tile::MMAS_K };\n        // The number of MMAs in the K dimension when we include padding.\n        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n        // Assemble the mask.\n        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n        // Reset the read offset.\n        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >\n    : public Smem_tile_col_b<Cta_tile,\n                             BYTES_PER_STS,\n                             BUFFERS_PER_TILE> {\n\n    // The base class.\n    using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n    
// Ctor.\n    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<  int N >\nstruct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,\n    // How many cols to use for the XOR pattern to avoid bank conflicts?\n    int COLS_PER_XOR_PATTERN_ = 1\n>\nstruct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,\n                                                               Cta_tile::K,\n                                                               Cta_tile::N,\n                                                               fmha::BITS_PER_ELEMENT_B,\n                                                               BYTES_PER_STS,\n                                                               BUFFERS_PER_TILE,\n                                                               0,\n                                                               ROWS_PER_XOR_PATTERN_,\n                                                               COLS_PER_XOR_PATTERN_> {\n\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::K,\n                                         Cta_tile::N,\n                                         fmha::BITS_PER_ELEMENT_B,\n                                         BYTES_PER_STS,\n             
                            BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         COLS_PER_XOR_PATTERN_>;\n    // The fragment.\n    using Fragment = Fragment_b<Row>;\n\n    // Can we use LDSM? Not if the data type is 32 bits wide.\n    enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };\n    // The number of elements per LDS.\n    enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };\n\n    // The number of STS per thread.\n    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n    // The number of STS per thread must be at least 1.\n    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n    // Ctor.\n    inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n        static_assert(WARPS_K == 1);\n        static_assert(WARPS_M == 4 || WARPS_M == 8);\n        static_assert(WARPS_N == 1);\n\n        // The masks to select the warps.\n        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n        const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;\n\n        // The divisor for the warps.\n        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;\n        const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;\n\n        // The row/col read by the thread.\n        int smem_read_row, smem_read_col;\n\n        static_assert(USE_LDSMT);\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n        smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +\n                        (tidx & 0x07) + (tidx & 
0x08);\n        smem_read_col = (tidx & 0x07);\n        smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n\n        // Fill zeroes for group conv\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // The size of each element in bits.\n        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n        // The size in bytes of the data needed to compute an MMA per CTA.\n        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Undo the pointer increment for the next ni.\n            // Should match the load function below for ki = 0.\n            if( BYTES_PER_MMA_PER_CTA >= 128 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {\n                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n            }\n        }\n\n        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&\n                Mma_tile::MMAS_N % 2 == 1 ) {\n            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n        // The size of each element in bits.\n        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n        // The size in bytes of the data needed to compute an MMA per CTA.\n        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Prepare the offset.\n            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;\n                if ( BYTES_PER_MMA_PER_CTA == 32 ) {\n                    offset += this->smem_read_offset_;\n                } else if ( BYTES_PER_MMA_PER_CTA == 64 ) {\n                    offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;\n                } else {\n                    offset += this->smem_read_offset_ + (ni  ) * BYTES_PER_MMA_PER_CTA;\n                }\n\n            // Load the data using LDSM.MT88.2.\n            uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;\n            uint4 tmp;\n            if( USE_LDSMT ) {\n                ldsmt(tmp, ptr);\n            } else {\n                lds(tmp.x, (ptr     ) + 0*Base::BYTES_PER_ROW);\n                lds(tmp.y, (ptr     ) + 4*Base::BYTES_PER_ROW);\n                lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW);\n                lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW);\n            }\n\n            // Store those values in the 
fragment.\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n\n            // Move the pointer for the next ni. I expect the compiler to not recompute those.\n            if( BYTES_PER_MMA_PER_CTA >= 128 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {\n                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n            }\n        }\n\n        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&\n                Mma_tile::MMAS_N % 2 == 1 ) {\n            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_b<Cta_tile,\n                             BYTES_PER_STS,\n                             BUFFERS_PER_TILE> {\n\n    // The base class.\n    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n    // Ctor.\n    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {\n\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The fragment.\n    using Fragment = Fragment_b< fmha::Col>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // Ctor.\n    inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {\n\n        // The row/col read by the thread.\n        int read_row, read_col;\n\n        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n        read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);\n        read_col = (tidx & 0x07);\n        read_col ^= (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n#pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Jump by 16 * #warps row.\n            int row = ki * 16 * Cta_tile::WARPS_K;\n\n            // Load the data using LDSM.MT88.2.\n            uint4 tmp;\n            fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n\n            // Move the pointer for the next ni. I expect the compiler to not recompute those.\n            if( Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n            } else {\n                assert(false);  // Not implemented!\n            }\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Smem_tile_o {\n\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The accumulators.\n    using Accumulator = fmha::Fragment_accumulator;\n    // The accumulators.\n    using Data_type = typename Accumulator::Data_type;\n\n    // The size of each element.\n    enum { BYTES_PER_ELEMENT = sizeof(Data_type) };\n    // The size of each STS.\n    enum { BYTES_PER_STS = 8 };\n    // The size of each row in shared memory.\n    enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };\n\n    // The size of each LDS.\n    enum { BYTES_PER_LDS = 16 };\n    enum { THREADS_PER_ROW = 16 };\n\n    // The number of rows.\n    enum { ROWS = Cta_tile::M };\n    // The number of \"rows\" to process per loop iteration (in the \"epilogue\").\n    enum { ROWS_PER_LOOP = ROWS <= 64 ? 
ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of outer loops.\n    enum { LOOPS = ROWS / ROWS_PER_LOOP };\n    // Make sure it matches our expectations.\n    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n    // The number of rows loaded per LDS.\n    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Do we have to guard against partial writes/reads.\n    enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };\n    // The total number of LDS per loop.\n    enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };\n\n    // The amount of shared memory.\n    enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };\n\n    // The write pointer.\n    uint32_t smem_write_, smem_read_;\n    // Is the thread active for the last LDS of the series?\n    int is_active_for_last_lds_;\n\n    static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);\n    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n    // Ctor.\n    inline __device__ Smem_tile_o(void *smem, int tidx) {\n\n        // Get a 32-bit value for the shared memory address.\n        uint32_t smem_ = __nvvm_get_smem_pointer(smem);\n\n        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n        int write_row = (tidx & 0x1c) / 4;\n        int write_col = (tidx);\n\n        // Assemble the write pointer.\n        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n\n        // The element read by each thread.\n        int read_row = tidx / THREADS_PER_ROW;\n        int read_col = tidx % THREADS_PER_ROW;\n\n        // Take the XOR pattern into account for the column.\n        read_col ^= 2 * (read_row & 0x7);\n\n        // Assemble the read pointer.\n        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n\n        // Is that thread active on the last LDS?\n        if( 
HAS_INCOMPLETE_LDS ) {\n            this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;\n        }\n    }\n\n    // Load the output fragments.\n    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {\n        #pragma unroll\n        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {\n\n            // Load the elements before the reduction (split-K).\n            uint4 tmp[Cta_tile::WARPS_K];\n            #pragma unroll\n            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {\n                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;\n                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {\n                    fmha::lds(tmp[jj], this->smem_read_ + imm);\n                }\n            }\n\n            // Perform the reduction.\n            out[ii] = tmp[0];\n            #pragma unroll\n            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {\n                out[ii] = fmha::fadd4(out[ii], tmp[jj]);\n            }\n        }\n    }\n    // Store the accumulators.\n    template <int M, int N>\n    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {\n        enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n\n            // The number of MMAs that are stored per loop iteration.\n            enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };\n\n            // Store 1st column of the different MMAs.\n            #pragma unroll\n            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {\n                // Precompute the immediates to jump between rows.\n                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n                uint2 tmp0, tmp1;\n                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);\n                tmp0.y = acc[mi * 
MMAS_M_PER_LOOP + mj][ni].reg(1);\n\n                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);\n                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);\n\n                // Store.\n                fmha::sts(this->smem_write_ + row_0, tmp0);\n                fmha::sts(this->smem_write_ + row_1, tmp1);\n            }\n\n            // Swizzle the write pointer using a XOR of 32B.\n            this->smem_write_ ^= 32;\n\n            // Store 2nd column of the different MMAs.\n            #pragma unroll\n            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {\n                // Precompute the immediates to jump between rows.\n                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n\n                uint2 tmp0, tmp1;\n                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);\n                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);\n\n                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);\n                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);\n                // Store.\n                fmha::sts(this->smem_write_ + row_0, tmp0);\n                fmha::sts(this->smem_write_ + row_1, tmp1);\n            }\n\n            // Cancel the previous XOR of 32B and swizzle the write pointer using a net XOR of 64B or 192B.\n            this->smem_write_ ^= (ni & 1) ? 
7 * 32 : 3 * 32;\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Smem_tile_mma {\n\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    using Fragment = fmha::Fragment_a<fmha::Col>;\n\n    enum { COLS = Cta_tile::N };\n    enum { BYTES_PER_ELT = 2 };\n    enum { BYTES_PER_STS = 4 };\n    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO\n    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };\n\n    enum { WARPS_M = Cta_tile::WARPS_M };\n    enum { WARPS_N = Cta_tile::WARPS_N };\n    enum { WARPS_K = Cta_tile::WARPS_K };\n\n    static_assert(WARPS_K == 1);\n    inline __device__ Smem_tile_mma(char *smem, int tidx) {\n        smem_ = __nvvm_get_smem_pointer(smem);\n\n        int write_col, write_row;\n        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);\n        if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {\n            write_row = (tidx & 0x1c) / 4;\n            write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);\n        } else {\n            write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;\n            write_col = (tidx & 0x03);\n        }\n        write_col ^= (write_row & 0x07) * 4;\n\n        write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n    }\n\n    template<int M, int N>\n    inline __device__ void store(const uint4 (&regs)[M][N]) {\n        static_assert(COLS == Cta_tile::N);\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n                offset ^= 4 * BYTES_PER_STS;\n                fmha::sts(smem_ + offset + 0 * 
BYTES_PER_ROW, regs[mi][ni].y);\n                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n            }\n        }\n    }\n\n    uint32_t smem_;\n    uint32_t write_offset_;\n    uint32_t warp_m;\n    uint32_t warp_n;\n    uint32_t lane;\n};\n\ntemplate< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>\nstruct Smem_tile_mma_transposed : public Base {\n    enum { BYTES_PER_LDS = 16 };\n    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n    enum { WARPS_M = Base::WARPS_M };\n    enum { WARPS_N = Base::WARPS_N };\n    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n    using Fragment = typename Base::Fragment;\n    inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {\n\n        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n        int read_row, read_col;\n        read_row = (tidx & 0x0f);\n        read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;\n\n        read_col ^= (read_row & 0x07);\n        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    template<int M, int N>\n    inline __device__ void load(Fragment (&frag)[M][N]) {\n        static_assert(Base::COLS == Cta_tile::N);\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n                uint4 dst;\n                fmha::ldsmt(dst, this->smem_ + offset);\n                frag[mi][ni].reg(0) = dst.x;\n                frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!\n                frag[mi][ni].reg(2) = dst.y;\n                frag[mi][ni].reg(3) = dst.w;\n            }\n        }\n    }\n\n    uint32_t read_offset_;\n};\n\ntemplate< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>\nstruct Smem_tile_mma_epilogue : public Base {\n    enum { BYTES_PER_LDS 
= 16 };\n    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };\n    static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);\n    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };\n    static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);\n    enum { WARPS_M = Base::WARPS_M };\n    enum { WARPS_N = Base::WARPS_N };\n    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);\n    \n    using Acc = fmha::Fragment_accumulator;\n\n    inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {\n        const int read_row = tidx / THREADS_PER_ROW;\n        int read_col = tidx % THREADS_PER_ROW;\n        read_col ^= (read_row & 0x07);\n        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    inline __device__ void load(uint4 (&data)[NUM_LDS]) {\n        for( int ii = 0; ii < NUM_LDS; ii++ ) {\n            size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;\n            fmha::lds(data[ii], this->smem_ + offset);\n        }\n    }\n\n    template<int M, int N>\n    inline __device__ void store(const Acc (&acc)[M][N]){\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                // 1st row - 4 elements per row.\n                float tmp00 = acc[mi][ni].elt(0);\n                float tmp01 = acc[mi][ni].elt(1);\n                float tmp02 = acc[mi][ni].elt(4);\n                float tmp03 = acc[mi][ni].elt(5);\n                // 2nd row - 4 elements per row.\n                float tmp10 = acc[mi][ni].elt(2);\n                float tmp11 = acc[mi][ni].elt(3);\n                float tmp12 = acc[mi][ni].elt(6);\n                float tmp13 = acc[mi][ni].elt(7);\n\n                uint32_t x = 
fmha::float2_to_half2(tmp00, tmp01);\n                uint32_t y = fmha::float2_to_half2(tmp02, tmp03);\n                uint32_t z = fmha::float2_to_half2(tmp10, tmp11);\n                uint32_t w = fmha::float2_to_half2(tmp12, tmp13);\n     \n                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);\n                offset ^= 4 * Base::BYTES_PER_STS;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);\n            }\n        }\n    }\n\n\n\n    template<int M, int N>\n    inline __device__ void store(const uint4 (&regs)[M][N]) {\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n                offset ^= 4 * Base::BYTES_PER_STS;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n            }\n        }\n    }\n\n    uint32_t read_offset_;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/softmax.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Sum_ {\n    enum { IS_SUM = 1 };\n    static inline __device__ float apply(float x, float y) {\n        return x + y;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Max_ {\n    enum { IS_SUM = 0 };\n    static inline __device__ float apply(float x, float y) {\n        return x > y ? 
x : y;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ float apply_exp_(float x, float max) {\n    return __expf(x - max);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile, typename Kernel_traits>\nstruct Softmax_base {\n\n    // The Mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // The number of MMAs in M/N dimensions.\n    enum { MMAS_M = Mma_tile::MMAS_M };\n    enum { MMAS_N = Mma_tile::MMAS_N };\n\n    // The number of groups of warp such that we have at most 4 warps writing consecutive elements.\n    enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };\n    // The number of elements that we are going to store per row.\n    enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };\n    // The number of rows.\n    enum { ROWS = Cta_tile::M * GROUPS };\n    // The total number of elements.\n    enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)\n        :  // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),\n          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {\n\n        // Move to the 1st mask loaded by the thread+ tidx;\n        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);\n\n        // Extract the position in the warp.\n        int warp = tidx / Cta_tile::THREADS_PER_WARP;\n        int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n        // Decompose the warp index into M and N.\n        int warp_m = warp % Cta_tile::WARPS_M;\n        int warp_n = warp / Cta_tile::WARPS_M;\n\n        // Decompose the warp-n index into group/position-inside-the-group.\n        int warp_g = warp_n / ELEMENTS_PER_ROW;\n        int warp_i = warp_n % ELEMENTS_PER_ROW;\n\n    
    // The location written by the threads.\n        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;\n        int write_col = warp_i;\n\n        // Assemble the write pointer.\n        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];\n\n        // Assemble the read pointer.\n        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];\n    }\n\n    template<typename Mask>\n    inline __device__ void apply_mask(const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ++ii ) {\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; ++jj ) {\n                        if( !mask.is_valid(mi, ni, ii, jj) ) {\n                            elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // Apply the exp to all the elements.\n    inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {\n                elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);\n            }\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {\n\n#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        if( Functor::IS_SUM ) {\n            // Apply the summation inside the thread.\n            float tmp[MMAS_M * 2][2];\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                tmp[mi][0] = 0.f;\n                tmp[mi][1] = 0.f;\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n  
                  tmp[mi][0] += elt_[mi][4 * ni + 0];\n                    tmp[mi][0] += elt_[mi][4 * ni + 1];\n                    tmp[mi][1] += elt_[mi][4 * ni + 2];\n                    tmp[mi][1] += elt_[mi][4 * ni + 3];\n                }\n                dst[mi] = tmp[mi][0] + tmp[mi][1];\n            }\n        } else\n#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        {\n            // Apply the functor for each row inside a thread.\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                dst[mi] = elt_[mi][0];\n                #pragma unroll\n                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {\n                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);\n                }\n            }\n        }\n\n        // Apply the functor for each row inside each group of 4 threads.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));\n            __syncwarp();\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));\n            __syncwarp();\n        }\n\n        // Store the different values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            if( tidx_ % 4 == 0 ) {\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];\n            }\n        }\n\n        // Make sure the values are in shared memory.\n        __syncthreads();\n\n        // Load 8 values (one for each warp). 
The /8 corresponds to /(4*2) where 4 is from the\n        // float4.\n        float4 tmp[1];\n        if( tidx_ < Cta_tile::M ) {\n            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];\n        }\n\n        // Compute the reduction of those 8 values in a binary-tree fashion.\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);\n        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);\n\n        // Make sure we can write to shared memory.\n        __syncthreads();\n\n        // Store the value back to shared memory.\n        if( tidx_ < Cta_tile::M ) {\n            smem_[tidx_] = tmp[0].x;\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Finally read the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];\n            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {\n\n#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        if( Functor::IS_SUM ) {\n            // Apply the summation inside the thread.\n            float tmp[MMAS_M * 2][2];\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                tmp[mi][0] = 0.f;\n                tmp[mi][1] = 0.f;\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n                    tmp[mi][0] += elt_[mi][4 * ni + 0];\n                    tmp[mi][0] += elt_[mi][4 * ni + 1];\n                    tmp[mi][1] += elt_[mi][4 * ni + 2];\n                    tmp[mi][1] += elt_[mi][4 * ni + 3];\n                }\n                dst[mi] = tmp[mi][0] + tmp[mi][1];\n            }\n        } else\n#endif  // 
defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        {\n            // Apply the functor for each row inside a thread.\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                dst[mi] = elt_[mi][0];\n                #pragma unroll\n                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {\n                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);\n                }\n            }\n        }\n\n        // Apply the functor for each row inside each group of 4 threads.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));\n            __syncwarp();\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));\n            __syncwarp();\n        }\n\n        // Store the different values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            if( tidx_ % 4 == 0 ) {\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];\n            }\n        }\n\n        // Make sure the values are in shared memory.\n        __syncthreads();\n\n        // Load 8 values (one for each warp). 
The /8 corresponds to /(4*2) where 4 is from the\n        // float4.\n        float4 tmp[2];\n        if( tidx_ < Cta_tile::M ) {\n            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];\n            tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];\n        }\n\n        // Compute the reduction of those 8 values in a binary-tree fashion.\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);\n        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);\n        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);\n        tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);\n        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);\n\n        // Make sure we can write to shared memory.\n        __syncthreads();\n\n        // Store the value back to shared memory.\n        if( tidx_ < Cta_tile::M ) {\n            smem_[tidx_] = tmp[0].x;\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Finally read the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];\n            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {\n        static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));\n        if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {\n            reduce_1x4<Functor>(dst);\n        } else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {\n            reduce_1x8<Functor>(dst);\n        } else {\n            assert(false);\n        }\n\n        // Make sure we are done reading from shared memory.\n        __syncthreads();\n    }\n\n    // 
Scale all the elements.\n    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {\n        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.\n        float inv_sum[MMAS_M * 2];\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];\n        }\n\n        // Update the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {\n                elt_[mi][ni] *= inv_sum[mi];\n            }\n        }\n    }\n\n    // The pointer to the mask.\n    const char *packed_mask_ptr_;\n    // Shared memory for the CTA-wide reduction.\n    float *smem_, *smem_write_, *smem_read_;\n    // The current thread index.\n    int tidx_;\n    // The elements.\n    float elt_[MMAS_M * 2][MMAS_N * 4];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile, typename Kernel_traits>\nstruct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {\n\n    // The base class.\n    using Base = Softmax_base<Cta_tile, Kernel_traits>;\n    // The fragment.\n    using Fragment_a = fmha::Fragment_a<fmha::Row>;\n\n    static_assert(Fragment_a::NUM_REGS == 4);\n\n    // The MMAs.\n    enum { MMAS_M = Base::MMAS_M };\n    enum { MMAS_N = Base::MMAS_N };\n\n    // The accumulators.\n    using Accumulator = fmha::Fragment_accumulator;\n    using Accumulator_out = Fragment<uint16_t, 8>;\n    static_assert(Accumulator_out::NUM_REGS == 4);\n\n    static_assert(std::is_same<Accumulator::Data_type, float>::value);\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)\n        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {\n    }\n\n    // Store the tile after 
softmax.\n    template<typename Gmem_tile>\n    inline __device__ void store(Gmem_tile &gmem_tile) {\n        Accumulator_out acc[MMAS_M][MMAS_N];\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N; ++ni ) {\n\n                // The elements.\n                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];\n                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];\n                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];\n                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];\n                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];\n                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];\n                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];\n                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];\n\n                // Transform to accumulators.\n                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);\n                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);\n                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);\n                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);\n            }\n        }\n\n        // Delegate to the gmem tile to store.\n        gmem_tile.store(acc);\n    }\n\n    // Pack the data to a fragment for the next GEMM.\n    template<int K, int M>\n    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {\n        #pragma unroll\n        for( int mi = 0; mi < M; ++mi ) {\n            #pragma unroll\n            for( int ki = 0; ki < K; ++ki ) {\n\n                // 1st row - 4 elements per row.\n                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];\n                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];\n                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];\n                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];\n\n                // 2nd row - 
4 elements per row.\n                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];\n                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];\n                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];\n                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];\n\n                // Pack to 4 registers.\n                dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);\n                dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);\n                dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);\n                dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);\n            }\n        }\n    }\n\n    // Scale FP32 fragments\n    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {\n        const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);\n\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N; ++ni ) {\n                // 1st row - 4 elements per row.\n                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;\n                // 2nd row - 4 elements per row.\n                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;\n            }\n        }\n    }\n    const uint32_t params_scale_bmm1_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha/utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdlib.h>\n\nextern \"C\" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Row {};  \nstruct Col {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, bool = (M & (M-1)) == 0 >\nstruct Next_power_of_two {\n};\n\ntemplate< int M >\nstruct Next_power_of_two<  M, true > { enum { VALUE =   M }; };\ntemplate<>\nstruct Next_power_of_two<  3, false> { enum { VALUE =   4 }; };\ntemplate<>\nstruct Next_power_of_two<  5, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  6, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  7, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  9, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 10, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 11, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 12, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 13, false> { enum { VALUE =  16 }; 
};\ntemplate<>\nstruct Next_power_of_two< 14, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 15, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 24, false> { enum { VALUE =  32 }; };\ntemplate<>\nstruct Next_power_of_two< 48, false> { enum { VALUE =  64 }; };\ntemplate<>\nstruct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two<112, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two<144, false> { enum { VALUE = 256 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, bool = (N & (N-1)) == 0 >\nstruct Prev_power_of_two {\n};\n\ntemplate< int N >\nstruct Prev_power_of_two< N, true > { enum { VALUE = N }; };\ntemplate<>\nstruct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };\ntemplate<>\nstruct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };\ntemplate<>\nstruct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };\ntemplate<>\nstruct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, int N >\nstruct Div_up {\n    enum { VALUE = (M + N-1) / N };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B >\nstruct Max {\n    enum { VALUE = A >= B ? A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B, int C >\nstruct Max_3 {\n    enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B >\nstruct Min {\n    enum { VALUE = A <= B ? 
A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int SIZE_IN_BYTES >\nstruct Uint_from_size_in_bytes {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<1> {\n    using Type = uint8_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<2> {\n    using Type = uint16_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<4> {\n    using Type = uint32_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<8> {\n    using Type = uint2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<16> {\n    using Type = uint4;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int WARPS_M, int WARPS_N, int WARPS_K >\nstruct Warp_masks {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; 
};\ntemplate<>\nstruct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };\ntemplate<>\nstruct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };\ntemplate<>\nstruct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename T >\ninline __device__ __host__ T div_up(T m, T n) {\n    return (m + n-1) / n;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int clz(int x) {\n    for( int i = 31; i >= 0; --i ) {\n        if( (1 << i) & x ) {\n            return 31 - i;\n        }\n    }\n    return 32;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int find_log_2(int x, bool round_up = false) {\n    int a = 31 - clz(x);\n    if( round_up ) {\n        a += (x & (x-1)) ? 
1 : 0;\n    }\n    return a;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"add.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"min.f16x2 %0, %1, %2;\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"mul.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hmul4(uint2 a, uint2 b) {\n    uint2 c;\n    c.x = hmul2(a.x, b.x);\n    c.y = hmul2(a.y, b.y);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint4 a, uint4 b) {\n    uint4 c;\n    c.x = hmul2(a.x, b.x);\n    c.y = hmul2(a.y, b.y);\n    c.z = hmul2(a.z, b.z);\n    c.w = hmul2(a.w, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint32_t a, uint4 b) {\n    uint4 c;\n    c.x = hmul2(a, b.x);\n    c.y = hmul2(a, b.y);\n    c.z = hmul2(a, b.z);\n    c.w = hmul2(a, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {\n    uint32_t res;\n#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ >= 800\n    asm volatile( \"max.f16x2 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(lb));\n#else\n    const uint32_t zero = 0u;\n    asm volatile( \\\n        \"{\\n\" \\\n        \"\\t .reg .f16x2 sela;\\n\" \\\n        \"\\t set.gtu.u32.f16x2 sela, %1, %2;\\n\" \\\n        \"\\t and.b32 %0, sela, %1;\\n\" \n        \"}\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(zero));\n#endif\n    return res;\n}\nstatic inline __device__ uint32_t habs2(uint32_t x) {\n    uint32_t res;\n    asm volatile( \"abs.f16x2 %0, %1;\\n\" : \"=r\"(res) : \"r\"(x));\n    return res;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\ntemplate< typename T >\nstatic inline __device__ T clamp(T x, T lb, T ub) {\n    return x < lb ? lb : (x > ub ? ub : x);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t clamp_to_zero(uint16_t x) {\n    uint16_t mask;\n    asm volatile(\"set.gtu %0, %1, 0;\" : \"=h\"(mask) : \"h\"(x));\n    return mask & x;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t float_to_half(float f) {\n    uint16_t h;\n    asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(h) : \"f\"(f));\n    return h;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(float a, float b) {\n    uint32_t c;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    asm volatile(\"cvt.rn.f16x2.f32 %0, %1, %2;\\n\" : \"=r\"(c) : \"f\"(b), \"f\"(a));\n#else\n    uint16_t lo = float_to_half(a);\n    uint16_t hi = float_to_half(b);\n    asm volatile(\"mov.b32 %0, {%1, %2};\\n\" : \"=r\"(c) : \"h\"(lo), \"h\"(hi));\n#endif\n    return 
c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float_to_half2(float a) {\n    return float2_to_half2(a,a);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(const float2 &f) {\n    return float2_to_half2(f.x, f.y);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {\n    uint2 d;\n    d.x = float2_to_half2(x, y);\n    d.y = float2_to_half2(z, w);\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {\n    uint32_t d;\n    asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {\n    uint32_t d;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    asm volatile(\"fma.rn.f16x2.relu %0, %1, %2, %3;\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n#else\n    d = hrelu2(hfma2(a, b, c));\n#endif\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h0_h0(uint32_t x) {\n    uint32_t y;\n    asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\\n\" \n        : \"=r\"(y) : \"r\"(x)); \n    return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float h0_to_float(uint32_t h2) {\n    float f;\n    asm volatile(\"{\\n\" \\\n    
    \".reg .f16 lo, hi;\\n\" \\\n        \"mov.b32 {lo, hi}, %1;\\n\" \\\n        \"cvt.f32.f16 %0, lo;\\n\" \\\n        \"}\\n\" : \"=f\"(f) : \"r\"(h2));\n    return f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h1_h1(uint32_t x) {\n    uint32_t y;\n    asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\\n\" \n        : \"=r\"(y) : \"r\"(x)); \n    return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {\n    uint16_t d;\n    asm volatile(\"add.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {\n    return hadd2(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd4(uint2 a, uint2 b) {\n    uint2 c;\n    c.x = hadd2(a.x, b.x);\n    c.y = hadd2(a.y, b.y);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd(uint2 a, uint2 b) {\n    return hadd4(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd8(uint4 a, uint4 b) {\n    uint4 c;\n    c.x = hadd2(a.x, b.x);\n    c.y = hadd2(a.y, b.y);\n    c.z = hadd2(a.z, b.z);\n    c.w = hadd2(a.w, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 fadd4(uint4 a, uint4 b) {\n    float4 c;\n    c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);\n    
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);\n    c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);\n    c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);\n    return reinterpret_cast<const uint4&>(c);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd(uint4 a, uint4 b) {\n    return hadd8(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float half_to_float(uint16_t h) {\n    float f;\n    asm volatile(\"cvt.f32.f16 %0, %1;\\n\" : \"=f\"(f) : \"h\"(h));\n    return f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float2 half2_to_float2(uint32_t x) {\n    uint16_t lo, hi;\n    asm volatile(\"mov.b32 {%0, %1}, %2;\\n\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(x));\n    return make_float2(half_to_float(lo), half_to_float(hi));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {\n    float2 tmp = half2_to_float2(h);\n    x = tmp.x;\n    y = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {\n    uint16_t d;\n    asm volatile(\"fma.rn.f16 %0, %1, %2, %3;\" : \"=h\"(d) : \"h\"(a), \"h\"(b), \"h\"(c));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {\n    uint16_t d;\n    asm volatile(\"mul.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n    return 
d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float sigmoid(float x) {\n    return 1.f / (1.f + expf(-x));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint16_t &dst) {\n    dst = uint16_t(0);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint32_t &dst) {\n    dst = 0u;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint2 &dst) {\n    dst = make_uint2(0u, 0u);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint4 &dst) {\n    dst = make_uint4(0u, 0u, 0u, 0u);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// P R E D I C A T E   P A C K I N G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\nenum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// G E N E R I C   P R E D I C A T E D   L D G S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M, typename Functor >\ninline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {\n\n    // The number of complete bytes (where we use all the predicates in a byte).\n    enum { COMPLETE = N / PREDS_PER_BYTE };\n    // Make sure we did allocate enough predicates.\n    static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, \"\");\n    // The remainder.\n    enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE 
};\n    // Make sure we got the math right and the remainder is between 0 and 3.\n    static_assert(REMAINDER >= 0 && REMAINDER <= 3, \"\");\n    // The mask to extract the predicates.\n    enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };\n\n    // Clear the fetch registers.\n    #pragma unroll\n    for( int ii = 0; ii < N; ++ii ) {\n        fct.clear(ii);\n    }\n\n    // Run complete steps.\n    bool p[PREDS_PER_BYTE];\n    #pragma unroll\n    for( int ii = 0; ii < COMPLETE; ++ii ) {\n\n        // The predicate.\n        uint32_t reg = preds[ii / BYTES_PER_REG];\n\n        // Extract the predicates.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);\n            p[jj] = (reg & mask) != 0u;\n        }\n\n        // Issue the loads.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);\n        }\n    }\n\n    // Skip the rest of the code if we do not have a remainder.\n    if( REMAINDER > 0 ) {\n\n        // The mask to extract the predicates.\n        enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };\n\n        // The predicate register.\n        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];\n\n        // Extract the predicates.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);\n            p[jj] = (reg & mask) != 0u;\n        }\n\n        // Issue the loads.\n        #pragma unroll\n        for( int ii = 0; ii < REMAINDER; ++ii ) {\n            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, typename Functor >\ninline __device__ void load_(Functor &fct, uint32_t preds) {\n    uint32_t tmp[1] = { preds };\n    load_<M>(fct, 
tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint8_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint8_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint16_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint16_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint32_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint32_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint2 &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint2*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint4 &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint4*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N >\nstruct Ldg_functor {\n    // Ctor.\n    inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])\n        : fetch_(fetch), ptrs_(ptrs) {\n    }\n\n    // Clear the element.\n    inline __device__ void clear(int ii) {\n        fmha::clear(fetch_[ii]);\n    }\n\n    // Trigger the loads.\n    inline __device__ void load(int ii, bool p) {\n        if( p ) {\n            ldg(fetch_[ii], ptrs_[ii]);\n        }\n    }\n\n    // The fetch registers.\n    Data_type (&fetch_)[N];\n    // The pointers.\n    const void* 
(&ptrs_)[N];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N, int M >\ninline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    Ldg_functor<Data_type, N> fct(fetch, ptrs);\n    load_<N>(fct, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint8_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint16_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint32_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint2, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint4, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint16_t &dst, uint32_t 
ptr) {\n    asm volatile(\"ld.shared.b16 %0, [%1];\\n\" : \"=h\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint32_t &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.b32 %0, [%1];\\n\" : \"=r\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint2 &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.v2.b32 {%0, %1}, [%2];\\n\" : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint4 &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x)\n        , \"=r\"(dst.y)\n        , \"=r\"(dst.z)\n        , \"=r\"(dst.w)\n        :  \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S M\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\\n\"\n        : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\\n\"\n        : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint2 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    
asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint4 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint8_t val) {\n    *reinterpret_cast<uint8_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint16_t val) {\n    *reinterpret_cast<uint16_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void 
*ptr, uint32_t val) {\n    *reinterpret_cast<uint32_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint2 val) {\n    *reinterpret_cast<uint2*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint4 val) {\n    *reinterpret_cast<uint4*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint16_t val) {\n    asm volatile(\"st.shared.b16 [%0], %1;\\n\" : : \"r\"(ptr), \"h\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint32_t val) {\n    asm volatile(\"st.shared.b32 [%0], %1;\\n\" : : \"r\"(ptr), \"r\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint2 val) {\n    asm volatile(\"st.shared.v2.b32 [%0], {%1, %2};\\n\"\n        :\n        : \"r\"(ptr)\n        , \"r\"(val.x)\n        , \"r\"(val.y));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint4 val) {\n    asm volatile(\"st.shared.v4.b32 [%0], {%1, %2, %3, %4};\\n\"\n        :\n        : \"r\"(ptr)\n        , \"r\"(val.x)\n        , \"r\"(val.y)\n        , \"r\"(val.z)\n        , \"r\"(val.w));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N >\ninline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {\n    #pragma 
unroll\n    for( int ii = 0; ii < N; ++ii ) {\n        sts(ptrs[ii], data[ii]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {\n    sts_<uint16_t, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {\n    sts_<uint32_t, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {\n    sts_<uint2, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {\n    sts_<uint4, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <cuda.h>\n#include <vector>\n\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n\n#include <fmha_utils.h>\n\n\nconstexpr int TOTAL_DIM = 0;\nconstexpr int THREE_DIM = 1;\nconstexpr int H_DIM = 2;\nconstexpr int D_DIM = 3;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Qkv_params {\n    // The QKV matrices.\n    void *qkv_ptr;\n\n    // The stride between rows of the Q, K and V matrices.\n    size_t qkv_stride_in_bytes;\n\n    // The number of heads.\n    int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fused_multihead_attention_fprop_params : public Qkv_params {\n\n    // The dQKV matrices.\n    void *dqkv_ptr;\n\n    // Temporary for dKV.\n    void *dkv_ptr;\n\n    // The O matrix (output).\n    void *o_ptr;\n\n    // The stride between rows of O.\n    int64_t o_stride_in_bytes;\n\n    // The pointer to the S matrix, overwritten by the dP matrix (bwd).\n    void *s_ptr;\n    // The stride between rows of the S matrix.\n    int64_t s_stride_in_bytes;\n\n    // The dimensions.\n    int b, s, d;\n\n    // The scaling factors for the kernel.\n    uint32_t scale_bmm1, scale_softmax, scale_bmm2;\n\n    // array of length b+1 holding starting offset of each sequence.\n    int 
*cu_seqlens;\n\n    // The dropout probability (probability of keeping an activation).\n    float p_dropout;\n\n    // Scale factor of 1 / (1 - p_dropout).\n    float rp_dropout;\n\n    // Scale factor of 1 / (1 - p_dropout), in half2.\n    uint32_t scale_dropout;\n\n    // Random state.\n    at::PhiloxCudaState philox_args;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\n\nvoid run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream); \n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);\n\nvoid fmha_run_noloop_reduce(void *out,\n                            const void *in,\n                            const int *cu_seqlens,\n                            const int hidden_size,\n                            const int batch_size,\n                            const int total,\n           
                 const int num_chunks,\n                            cudaStream_t stream);\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 128 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 256 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 384 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n#include \"fmha_dgrad_kernel_1xN_reload_nl.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\ntemplate<int CHUNKS>\n__global__\nvoid fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){\n    fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 512 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * 
Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 512 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    // Select the kernel instantiation matching the requested number of chunks.\n    auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n\n    if( num_chunks == 2 ) {\n        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n    } else if( num_chunks == 3 ) {\n        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;\n    } else {\n        assert(false && \"Unsupported 
number of chunks\");\n    }\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b, num_chunks);\n\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n\n    FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dv =\n        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n    static_assert(Cta_tile_dv::N == 64);\n    static_assert(Cta_tile_dv::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    // using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n    
// The shared memory tile to reload Q as fragment b.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dV.\n    using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle dV.\n    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n    static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n    static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q\n    // Allocate the shared memory tile loader for Q.\n    
Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n    static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n    static_assert(Mma_tile_dv::MMAS_K == 1);\n    smem_qt.load(frag_qt[0], 0);\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(\n        params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n    // Load over the entire sequence length.\n    for( int l = 0; l < STEPS; l++ ) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Load S\n        uint4 s_regs[M][N];\n        gmem_s.load(s_regs, mask);\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Store s * dmask to smem for transpose\n        smem_s.store(s_regs);\n\n        // Declare the accumulators for the 1st gemm.\n        // Do the 
final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n        // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe\n        if( l < STEPS - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        float s_mat[2 * M][4 * N];\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n            }\n        }\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n                        const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;\n                        const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n                        s_dmask = fabsf(s_dmask);\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);\n                    }\n                }\n            }\n        }\n\n        float p_sum[2 * M];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n                    }\n                }\n            }\n        }\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n        smem_s.load(frag_s);\n        for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {\n            for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {\n                for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {\n                    frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n                    frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        gmem_s.store(softmax.elt_, mask);\n        gmem_s.move();\n\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], 
frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dv::MMAS_K;\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n        // Commit the values for Q into shared memory.\n        if(l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure we are reading from the correct buffer.\n        smem_q.move_to_next_read_buffer();\n        smem_qt.move_to_next_read_buffer();\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n        smem_k.load(frag_k[0], 0);\n        smem_qt.load(frag_qt[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue swizzle for dV\n    Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n    smem_dv.store(acc_dv);\n\n    __syncthreads();\n    uint4 dv_out[Smem_tile_dv::NUM_LDS];\n    smem_dv.load(dv_out);\n    Qkv_params dv_params;\n    dv_params.qkv_ptr = params.dqkv_ptr;\n    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n    dv_params.h = params.h;\n    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);\n    gmem_dv.store(dv_out);\n}\n\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dk =\n        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n    static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n    static_assert(Cta_tile_dk::N == 64);\n    
static_assert(Cta_tile_dk::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dK.\n    using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle dK.\n    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n    static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);\n\n    // The shared memory tile to reload Q transposed.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n\n\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n    static_assert(M == 
Mma_tile_o::MMAS_M);\n    static_assert(N == Mma_tile_o::MMAS_K);\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. 
We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n    // Load dP\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n    gmem_s.move();\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n    smem_qt.load(frag_qt[0], 0);\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n    // Load over the entire sequence length.\n    for( int l=0;l<STEPS;l++) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Pack dP as Fragment_a\n        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                frag_p[ni][mi].reg(0) = dst.x;  // row 0, cols 0,1\n                frag_p[ni][mi].reg(1) = dst.z;  // row 8, cols 0,1\n                frag_p[ni][mi].reg(2) = dst.y;  // row 0, cols 8,9\n                frag_p[ni][mi].reg(3) = dst.w;  // row 8, cols 8,9\n            }\n        }\n\n        // Declare the 
accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T. dQ = dP x dK\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_o::MMAS_K;\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Store dP to smem for transpose\n        smem_s.store(s_regs);\n        if(l < STEPS - 1) {\n            // Load next part of S\n            gmem_s.load(s_regs, mask);\n            gmem_s.move();\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n        
smem_s.load(frag_s);\n\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dk::MMAS_K;\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Commit the values for Q into shared memory.\n        if( l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_qt.load(frag_qt[0], 0);\n        smem_k.load(frag_k[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue swizzle for dK\n    Smem_tile_dk smem_dk(&smem_[0], tidx);\n    smem_dk.store(acc_dk);\n    __syncthreads();\n    uint4 dk_out[Smem_tile_dk::NUM_LDS];\n    smem_dk.load(dk_out);\n    Qkv_params dk_params;\n    dk_params.qkv_ptr = params.dqkv_ptr;\n    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n    dk_params.h = params.h;\n    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);\n    gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dv = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n    static_assert(Cta_tile_dv::N == 64);\n    static_assert(Cta_tile_dv::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n    // The shared memory tile to reload Q as fragment b.\n    
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store dV.\n    using Gmem_tile_dv = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, \n                                             fmha::BITS_PER_ELEMENT_B, \n                                             Cta_tile_p::N, //S, \n                                             Cta_tile_p::K, //D, \n                                             2*CHUNKS>;\n\n    // The shared memory tile to swizzle dV.\n    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n    static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n    static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the chunk.\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_do gmem_q(params, binfo, 
tidx);  // treating dout as Q\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    Noloop nl_traits(bidc);\n    nl_traits.move_all(gmem_q, gmem_s);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n    static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n    static_assert(Mma_tile_dv::MMAS_K == 1);\n    smem_qt.load(frag_qt[0], 0);\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(\n        params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n    // Load over the entire sequence length.\n    for(int l = 0; l < nl_traits.num_steps_;l++) {\n        const int loop = nl_traits.offset_loop_count(l);\n        if( loop >= binfo.actual_seqlen ) break;\n\n        uint4 s_regs[M][N];\n        gmem_s.load(s_regs, mask);\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n\n        smem_s.store(s_regs);\n\n        // Declare the accumulators for the 1st gemm.\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n    
    // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe\n        if(l < nl_traits.num_steps_ - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        float s_mat[2 * M][4 * N];\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n            }\n        }\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n                        const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;\n                        const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n                        s_dmask = fabsf(s_dmask);\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);\n                    }\n                }\n            }\n        }\n\n        float p_sum[2 * M];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n                    }\n                }\n            }\n        }\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n        smem_s.load(frag_s);\n        for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {\n            for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {\n                for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {\n                    frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n                    frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        gmem_s.store(softmax.elt_, mask);\n        gmem_s.move();\n\n        static_assert(Mma_tile_dv::MMAS_K == 1);  // DEBUG\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in 
registers.\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dv::MMAS_K;\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n        // Commit the values for Q into shared memory.\n        if(l < nl_traits.num_steps_ - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure we are reading from the correct buffer.\n        smem_q.move_to_next_read_buffer();\n        smem_qt.move_to_next_read_buffer();\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n        smem_k.load(frag_k[0], 0);\n        smem_qt.load(frag_qt[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!\n\n    // Epilogue swizzle for dV\n    Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n    smem_dv.store(acc_dv);\n\n    __syncthreads();\n\n    uint4 dv_out[Smem_tile_dv::NUM_LDS];\n    smem_dv.load(dv_out);\n    Qkv_params dv_params;\n    dv_params.qkv_ptr = params.dkv_ptr;\n    dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n    dv_params.h = params.h;\n    Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);\n    gmem_dv.store(dv_out);\n}\n\ntemplate<int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dk = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 
1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n    static_assert(Cta_tile_dk::N == 64);\n    static_assert(Cta_tile_dk::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = Gmem_tile_dq<Cta_tile_o>;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dK.\n    using Gmem_tile_dk = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, \n                                             fmha::BITS_PER_ELEMENT_B, \n                                             Cta_tile_p::N, //S, \n                                             Cta_tile_p::K, //D, \n                                             2*CHUNKS>;\n\n    // The shared memory tile to swizzle dK.\n    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n    static_assert(Smem_tile_dk::THREADS_PER_ROW == 
Gmem_tile_dk::THREADS_PER_ROW);\n\n    // The shared memory tile to reload Q transposed.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    // The global memory tile to load dP, stored in S\n    using Gmem_tile_s = Gmem_tile_mma_s<Cta_tile_p>;\n    // The shared memory tile to transpose dP.\n    using Smem_tile_st = Smem_tile_mma_transposed<Cta_tile_p>;  \n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n    static_assert(M == Mma_tile_o::MMAS_M);\n    static_assert(N == Mma_tile_o::MMAS_K);\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q (as B).\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    // Allocate the global memory tile loader for dP.\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n    // Allocate the shared memory tile loader for dP.\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory 
tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    Noloop nl_traits(bidc);\n\n    nl_traits.move_all(gmem_q, gmem_o, gmem_s);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_qt);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_qt);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n    smem_qt.load(frag_qt[0], 0);\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    enum { THREADS_PER_ROW = 32 };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n    // Load over the entire sequence length.\n    for(int l=0;l < nl_traits.num_steps_; l++) {\n\n        // Pack dP as Fragment_a\n        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                frag_p[ni][mi].reg(0) = dst.x;\n                frag_p[ni][mi].reg(1) = dst.z;\n                frag_p[ni][mi].reg(2) = dst.y;\n                frag_p[ni][mi].reg(3) = dst.w;\n            }\n        }\n        smem_s.store(s_regs);\n        if(l < nl_traits.num_steps_- 1) {\n            // Load next part of S\n            gmem_s.move();\n            gmem_s.load(s_regs, mask);\n            // Trigger the load for the next Q values.\n          
  smem_qt.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_qt);\n        }\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of dQ = dP * K.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of K values.\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_o::MMAS_K;\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n        smem_s.load(frag_s);\n\n        static_assert(Mma_tile_dk::MMAS_K == 1);  // DEBUG\n\n        #pragma unroll\n        for( int ki = 1; ki < 
Mma_tile_dk::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dk::MMAS_K;\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Commit the values for Q into shared memory.\n        if(l < nl_traits.num_steps_ - 1) {\n            gmem_q.commit(smem_qt);\n            __syncthreads();\n            // Trigger the loads for the values of Q for the next iteration.\n            smem_qt.load(frag_qt[0], 0);\n            smem_k.load(frag_k[0], 0);\n        }\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue for dK = dP' * Q. We're fully exposed to this!\n\n    // Epilogue swizzle for dK\n    Smem_tile_dk smem_dk(&smem_[0], tidx);\n    smem_dk.store(acc_dk);\n\n    __syncthreads();\n\n    uint4 dk_out[Smem_tile_dk::NUM_LDS];\n    smem_dk.load(dk_out);\n    Qkv_params dk_params;\n    dk_params.qkv_ptr = params.dkv_ptr;\n    dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n    dk_params.h = params.h;\n    Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);\n    gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace fmha\n
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN_reload_v.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n#include \"fmha_fprop_kernel_1xN_nl.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\ntemplate<int CHUNKS>\n__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);\n}\n\ntemplate<int CHUNKS>\n__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);\n}\n\n\nvoid run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n\nvoid run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {\n\n    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;\n    if( num_chunks == 2 ) {\n        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;\n    } else if( num_chunks == 3 ) {\n        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;\n    } else if( num_chunks == 4 ) {\n        kernel = is_training ? 
&fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;\n    } else {\n        assert(false && \"Unsupported num_chunks\");\n    }\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b, num_chunks);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using 
Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    // Load the fragments for K. We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n    }\n\n    // Load the fragments for V. We keep the data in registers during the entire kernel.\n    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n        smem_v.load(frag_v[ki], ki);\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n\n    // Load over the entire sequence length.\n    for( int l = 0; l < STEPS; l++ ) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Load the mask for that iteration.\n        mask.load(l);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the mask.\n        softmax.apply_mask(mask);\n\n        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V 
&& l == 0 ) {\n            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction\n            __syncthreads();\n        }\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from\n                        // pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + 
ii][4 * ni + 3]);\n                    }\n                }\n            }\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        // Trigger the load for the next Q values.\n        if(l < STEPS - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n        using Frag_p = fmha::Fragment_a< fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    // \"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 2nd gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T.\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from 
shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if(l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>\ninline __device__ void device_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store S/D.\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    Noloop nl_traits(bidc);\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    nl_traits.move_all(gmem_q, gmem_o, gmem_s);\n\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    // Load the fragments for K. We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n    }\n\n    // Load the fragments for V. We keep the data in registers during the entire kernel.\n    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n        smem_v.load(frag_v[ki], ki);\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    // The number of threads per row.\n    enum { THREADS_PER_ROW = 32 };\n\n    // Load over the entire sequence length.\n    for( int l = 0; l < nl_traits.num_steps_; l++ ) {\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Trigger the load for the next Q values.\n        if( l < nl_traits.num_steps_ - 1 ) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n\n\n        // Load the mask for that iteration.\n        mask.load(nl_traits.loop_offset_ + l);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the 
mask.\n        softmax.apply_mask(mask);\n\n        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {\n            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction\n            __syncthreads();\n        }\n\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);\n                    }\n                }\n            }\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        using Frag_p = fmha::Fragment_a<fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    // \"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 2nd gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T.\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the 
values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if( l < nl_traits.num_steps_- 1) {\n            gmem_q.commit(smem_q);\n            __syncthreads();\n            smem_q.load(frag_q[0], 0);\n        }\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using 
Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[0];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);\n    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);\n    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));\n\n    enum { THREADS_PER_ROW = 32 };\n\n    const float pinv = 1.f / params.p_dropout;\n\n    // Load over the entire sequence length.\n    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[0], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);\n        }\n\n        // Load the mask 
for that iteration.\n        mask.load(outer);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the mask.\n        softmax.apply_mask(mask);\n\n        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);\n\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n\n        __syncthreads();\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from\n                        // pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);\n                    }\n                }\n            }\n\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        // Trigger the load for the next Q values.\n        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];\n\n        using Frag_p = fmha::Fragment_a< fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    // \"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 2nd gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of V values.\n            smem_v.load(frag_v[0], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < 
Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Always sync after last iter: shared smem_q and smem_o!\n            __syncthreads();\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n        // same smem as o\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <multihead_attn/philox.h>\n\n#include <fmha.h>\n#include <fmha/utils.h>\n#include <fmha/smem_tile.h>\n#include <fmha/gmem_tile.h>\n#include <fmha/mask.h>\n#include <fmha/softmax.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int THREADS_PER_CTA>\nstruct BlockInfoPadded {\n\n    template<typename Params>\n    __device__ BlockInfoPadded(const Params &params,\n                               const int bidb,\n                               const int bidh,\n                               const int tidx)\n        : bidb(bidb), bidh(bidh), h(params.h) {\n\n        // The block index.\n        sum_s = params.cu_seqlens[bidb];\n        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;\n        bidx = sum_s * params.h + bidh;\n\n        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;\n    }\n\n    __device__ bool stop_early() const {\n        return actual_seqlen == 0;\n    }\n\n    int actual_seqlen;\n    int bidx;\n    int sum_s;\n    int bidh;\n    int bidb;\n    int tidx_global;\n    int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Cta_tile> \nstruct Noloop_traits{\n    // Interpretation of Cta_tile dims, i.e. 
Cta_tile_p:\n    enum{ STEP = Cta_tile::M };\n    enum{ SEQLEN = Cta_tile::N };\n\n    // The size of the subsequence this CTA is processing\n    enum { SUBSEQ = SEQLEN / CHUNKS };\n    static_assert(SUBSEQ * CHUNKS == SEQLEN);\n\n    // The number of steps to process the subsequence\n    enum { NUM_STEPS = SUBSEQ / STEP };\n    static_assert(NUM_STEPS  * Cta_tile::M == SUBSEQ);\n\n    inline __device__ Noloop_traits(const int bidc) \n        : loop_offset_(NUM_STEPS * bidc)\n        , bidc_(bidc) {\n    }\n\n    template<typename ... Tiles> \n    inline __device__ void move_all(Tiles & ... tiles) const {\n        using expand_type = int[];\n        for( int s = 0; s < loop_offset_; s++ ) {\n            expand_type{ (tiles.move(), 0)... };\n        }\n    }\n\n    inline __device__ int get_idx_dk() const {\n        //return bidc_;\n        return bidc_ * 2 + 0;\n    }\n\n    inline __device__ int get_idx_dv() const {\n        //return CHUNKS + bidc_;\n        return bidc_ * 2 + 1;\n    }\n\n    inline __device__ int offset_loop_count(const int l) {\n        // convert loop counter to position in the outer sequence\n        return (loop_offset_ + l) * STEP;\n    }\n\n    const int loop_offset_;\n    const uint32_t bidc_;\n    const int num_steps_ = NUM_STEPS;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile> \nstruct Noloop_traits<3, Cta_tile>{\n    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:\n    enum{ STEP = Cta_tile::M };\n    enum{ SEQLEN = Cta_tile::N };\n\n    static_assert(STEP == 16 && SEQLEN == 512);\n\n    inline __device__ Noloop_traits(const int bidc)\n        : bidc_(bidc)\n        , num_steps_(bidc < 2 ? 11 : 10) \n        , loop_offset_(bidc * 11) {\n    }\n\n    template<typename ... Tiles> \n    inline __device__ void move_all(Tiles & ... 
tiles) const {\n        using expand_type = int[];\n        for( int s = 0; s < loop_offset_; s++ ) {\n            expand_type{ (tiles.move(), 0)... };\n        }\n    }\n\n    inline __device__ int get_idx_dk() const {\n        //return bidc_;\n        return bidc_ * 2 + 0;\n    }\n\n    inline __device__ int get_idx_dv() const {\n        //return CHUNKS + bidc_;\n        return bidc_ * 2 + 1;\n    }\n\n    inline __device__ int offset_loop_count(const int l) {\n        // convert loop counter to position in the outer sequence\n        return (loop_offset_ + l) * STEP;\n    }\n\n    const int loop_offset_;\n    const uint32_t bidc_;\n    const int  num_steps_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n\ninline __device__ float4 ldg128(const void *ptr) {\n    return *static_cast<const float4 *>(ptr);\n}\n\ninline __device__ void stg128(void *ptr, const float4 &data) {\n    *static_cast<float4 *>(ptr) = data;\n}\n\ntemplate<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>\n__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,\n                                                                     const void *__restrict__ in,\n                                                                     const int *__restrict__ cu_seqlens,\n                                                                     const int batch_size) {\n\n    enum { BYTES_PER_LDG = 16 };\n    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };\n\n    // One CTA hidden vector for K and V\n    enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };\n    // The stride in bytes in dQKV\n    enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };\n    // The offset in bytes in dQKV to the dKV part for non-interleaved heads\n    enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };\n\n    static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); \n\n    // Size in bytes of the input tile\n    enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };\n\n    enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };\n\n    enum { LDGS = BYTES_PER_ROW / 
BYTES_PER_CTA };\n    static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);\n\n    union Vec_t {\n        float4 raw;\n        T elt[NUM_ELTS];\n    };\n\n    // ZERO-OUT invalid positions in dQKV\n    const int total = cu_seqlens[batch_size];\n    if(blockIdx.x >= total){\n        enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };\n        enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };\n\n        const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);\n\n        char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;\n\n        for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){\n            stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);\n        }\n\n        return;\n    }\n\n    // SETUP\n    const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;\n    const char *ptr_in = static_cast<const char *>(in) + offset_in;\n\n    const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;\n    char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;\n\n    // LOAD\n\n    Vec_t local_in[CHUNKS][LDGS];\n\n    #pragma unroll\n    for( int c = 0; c < CHUNKS; c++ ) {\n        #pragma unroll\n        for( int l = 0; l < LDGS; l++ ) {\n            int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;\n            local_in[c][l].raw = ldg128(ptr_in + offset);\n        }\n    }\n\n    // UNPACK\n    float acc[LDGS][NUM_ELTS];\n\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        #pragma unroll\n        for( int e = 0; e < NUM_ELTS; e++ ) {\n            acc[l][e] = float(local_in[0][l].elt[e]);\n        }\n    }\n\n    // COMPUTE\n    #pragma unroll\n    for( int c = 1; c < CHUNKS; c++ ) {\n        #pragma unroll\n        for( int l = 0; l < LDGS; l++ ) {\n            #pragma unroll\n            for( int e = 0; e < NUM_ELTS; e++ ) {\n                acc[l][e] += float(local_in[c][l].elt[e]);\n            }\n        }\n    }\n\n    // PACK\n    Vec_t 
local_out[LDGS];\n\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        #pragma unroll\n        for( int e = 0; e < NUM_ELTS; e++ ) {\n            local_out[l].elt[e] = T(acc[l][e]);\n        }\n    }\n\n    // STORE\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        const int offset = l * BYTES_PER_CTA;\n        stg128(ptr_out + offset, local_out[l].raw);\n    }\n}\n\nvoid fmha_run_noloop_reduce(void *out,\n                            const void *in,\n                            const int *cu_seqlens,\n                            const int hidden_size,\n                            const int batch_size,\n                            const int total,\n                            const int num_chunks,\n                            cudaStream_t stream) {\n\n    const int blocks = total;\n\n    if(hidden_size == 1024){\n\n        constexpr int HIDDEN_SIZE = 1024;\n        constexpr int THREADS = 256;\n\n        if( num_chunks == 2 ) {\n            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n        } else if( num_chunks == 3 ) {\n            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n        } else {\n            assert(false && \"Unsupported num_chunks\");\n        }\n\n    }else{\n        assert(false && \"Unsupported hidden_size\");\n    }\n\n    FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/fmha/src/fmha_utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime_api.h>\n#include <cuda_fp16.h>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define FMHA_CHECK_CUDA( call )                                                                    \\\n    do {                                                                                           \\\n        cudaError_t status_ = call;                                                                \\\n        if( status_ != cudaSuccess ) {                                                             \\\n            fprintf( stderr,                                                                       \\\n                     \"CUDA error (%s:%d): %s\\n\",                                                   \\\n                     __FILE__,                                                                     \\\n                     __LINE__,                                                                     \\\n                     cudaGetErrorString( status_ ) );                                              \\\n            exit( 1 );                                                                             \\\n        }                                                                                          \\\n    } while( 0 
)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nenum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {\n    if( dtype == DATA_TYPE_FP16 ) {\n        half x = __float2half_rn( norm );\n        uint16_t h = reinterpret_cast<const uint16_t &>( x );\n        ushort2 h2 = { h, h };\n        alpha = reinterpret_cast<const uint32_t &>( h2 );\n    } else if( dtype == DATA_TYPE_FP32 ) {\n        alpha = reinterpret_cast<const uint32_t &>( norm );\n    } else if( dtype == DATA_TYPE_INT32 ) {\n        int32_t inorm = static_cast<int32_t>( norm );\n        alpha = reinterpret_cast<const uint32_t &>( inorm );\n    } else {\n        assert( false );\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {\n    switch( dtype ) {\n    case DATA_TYPE_FP32:\n        return n * 4;\n    case DATA_TYPE_FP16:\n        return n * 2;\n    case DATA_TYPE_INT32:\n        return n * 4;\n    case DATA_TYPE_INT8:\n        return n;\n    default:\n        assert( false );\n        return 0;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/batch_norm.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include \"batch_norm.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) {\n  return ((x + multiple - 1) / multiple) * multiple;\n}\n\n// TODO: Stop manually allocating CUDA memory; allocate an ATen byte\n// tensor instead.\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() {\n    if (data) {\n      THCudaFree(at::globalContext().lazyInitCUDA(), data);\n    }\n  }\n\n  size_t size;\n  void* data;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void * my_data,\n                       void * pair_data,\n                       void * pair_data2,\n  
                     void * pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void 
*> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             
y.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwdInference(stream, fuse_relu);\n\n  return y;\n\n}\n\nstd::vector<at::Tensor> nhwc_bn_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                   
    const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void * my_data,\n                       void * pair_data, \n                       void * pair_data2, \n                       void * pair_data3, \n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             x_grad.DATA_PTR<at::Half>(),\n                             nullptr,\n                             dy.DATA_PTR<at::Half>());\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll 
create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_fwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n\n    //max occupancy supported by the code is 2\n    return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_bwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n    \n    //max occupancy supported by the code is 2\n    return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);\n}\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/batch_norm.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <vector>\n#include <string>\n\n#include \"nhwc_batch_norm_kernel.h\"\n#include \"cuda_utils.h\"\n\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNorm {\n public:\n  NhwcBatchNorm() {\n    name_ = \"nhwc_batchnorm\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNorm() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnorm not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void dgrad(cudaStream_t stream, bool 
use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream, bool use_relu);\n  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format,\n                                  const cudnnDataType_t     data_type,\n                                  int n, int c, int h, int w, int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format,\n                                   const cudnnDataType_t     data_type,\n                                   int n, int c, int h, int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(\n      const std::vector<void*>&  workspace,\n      const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {\n    X_ = X;\n    dX_  = dX;\n    Y_   = Y;\n    dY_  = dY;\n  }\n\n  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers,\n                                 const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size()  == 2);\n    scale_  = static_cast<float*>(weight_pointers[0]);\n    bias_   = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_  = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_     = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status,\n                          const std::string& string = std::string(),\n                          bool verbose = VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << 
\" \" << cudnnGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  void checkCudaStatus(const std::string& string = std::string(),\n                       bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y)\n      LOG(INFO) << \"GPU capabilities exceeds assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t  X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t  Y_tensor_desc_ = nullptr;\n\n  void*  X_ = nullptr;\n  void* dX_ = nullptr;\n  void*  Y_ = nullptr;\n  void* dY_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_  = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_   = nullptr;\n  float* dbias_  = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_     = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_     = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors 
to get running variance\n\n  double exp_avg_factor_ = 0.;\n  double eps_            = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,\n                           cudnnTensorFormat_t format,\n                           cudnnDataType_t     data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float *partial_sums_ = nullptr;\n  int *partial_counts_ = nullptr;\n  int *retired_ctas_   = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams *params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams *params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  //typedef float StorageType;\n  // increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*2*sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,\n                                dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {\n\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto fwd_func = nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" fwd ser coop kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,\n                                dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {\n#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto bwd_func = nhwc_batch_norm_bwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                       
 THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" bwd coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_FUNC>(bwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd coop serial kernel\"); \\\n    } while (0)\n\n#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto bwd_relu_func = 
nhwc_batch_norm_bwd_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd-relu coop serial kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze 
register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_KERNEL(1, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_KERNEL\n  }\n\n public:\n\n  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  
int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes     = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL*\\\n      ELEMENTS_PER_LDG*2*sizeof(float);\n  const size_t size_counts        = grid_y*grid_x*sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes,\n          size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid NhwcBatchNorm::setWorkspacePointers(\n      const std::vector<void*>& workspace,\n      const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 5);\n  assert(num_workspace_bytes.size() == 5);\n\n  minibatch_mean_     = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  retired_ctas_       = static_cast<int*>(workspace[2]);\n  partial_sums_       = static_cast<float*>(workspace[3]);\n  partial_counts_     = static_cast<int*>(workspace[4]);\n}\n\nvoid NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dst          = static_cast<uint16_t*>(Y_);\n  params->gmem_src1         = nullptr;\n  params->gmem_bias         = bias_;\n  params->gmem_scale        = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var  = population_variance_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->gmem_relu_bitmask = nullptr;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->rvar_inv_count    = rvar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_counts       = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps           = eps_;\n  params->outer_loops       = 0;\n  params->exp_avg_factor    = 
static_cast<float>(exp_avg_factor_);\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams\n                                                        *params) const {\n  params->gmem_src   = static_cast<uint16_t*>(X_);\n  params->gmem_dst   = static_cast<uint16_t*>(Y_);\n  params->gmem_src1  = nullptr;\n  params->gmem_bias  = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean  = population_mean_;\n  params->gmem_var   = population_variance_;\n  params->nhw        = m_;\n  params->c          = c_;\n  params->var_eps    = eps_;\n}\n\nvoid NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dy           = static_cast<uint16_t*>(dY_);\n  params->gmem_dst          = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1         = nullptr;\n  params->gmem_relu_bitmask = nullptr;\n  params->gmem_dscale       = dscale_;\n  params->gmem_dbias        = dbias_;\n  params->gmem_scale        = scale_;\n  params->gmem_bias         = bias_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops       = 0;\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      //      && minibatch_mean_ != nullptr\n      //      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ 
!= nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  if (use_relu) {\n    nhwc_batch_norm_fwd_inference\n      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>\n    <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n  } else {\n    nhwc_batch_norm_fwd_inference\n      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>\n    <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference kernel\");\n  }\n}\n\ndim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int 
c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                        const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\nvoid NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, \n                          const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && (bias_ != nullptr || !use_relu)\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      //      && population_mean_ != nullptr\n      //      && population_variance_ != nullptr\n      && X_ != nullptr\n      && dX_ != nullptr\n      //      && Y_ != nullptr\n      && dY_ != nullptr\n      && dscale_ != nullptr\n      && dbias_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include \"batch_norm_add_relu.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n//FIXME move the common stuff to common h file\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) {\n  return ((x + multiple - 1) / multiple) * multiple;\n}\n\n// TODO: Stop manually allocating CUDA memory; allocate an ATen byte\n// tensor instead.\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() {\n    if (data) {\n      THCudaFree(at::globalContext().lazyInitCUDA(), data);\n    }\n  }\n\n  size_t size;\n  void* data;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_addrelu_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n            
           void * my_data,\n                       void * pair_data,\n                       void * pair_data2,\n                       void * pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generate a new magic number and use it for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // The first four workspace slots are filled in below; allocate & offset\n  // a single workspace buffer for the remaining entries\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    
workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n  workspace.push_back(bitmask.DATA_PTR<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Launch the fused BN + add + ReLU forward kernel\n  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, 
bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // The first four workspace slots are filled in below; allocate & offset\n  // a single workspace buffer for the remaining entries\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Launch the fused BN + add + ReLU inference kernel\n  
bn->fwdInference(stream);\n\n  return y;\n\n}\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       void * my_data,\n                       void * pair_data, \n                       void * pair_data2, \n                       void * pair_data3, \n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generate a new magic number and use it for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, z_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  z_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n            
                 x_grad.DATA_PTR<at::Half>(),\n                             nullptr,\n                             dy.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z_grad.DATA_PTR<at::Half>());\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // The first four workspace slots are filled in below; allocate & offset\n  // a single workspace buffer for the remaining entries\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n  workspace.push_back(bitmask.DATA_PTR<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return 
std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_addrelu_fwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n    \n    //max occupancy supported by the code is 2\n    return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_addrelu_bwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n\n    //max occupancy supported by the code is 2\n    return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/batch_norm_add_relu.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_add_relu.h\n * \\brief CUDA NHWC Batch Normalization code with fused addition\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <vector>\n#include <string>\n\n#include \"nhwc_batch_norm_kernel.h\"\n#include \"cuda_utils.h\"\n\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNormAddRelu {\n public:\n  NhwcBatchNormAddRelu() {\n    name_ = \"nhwc_batchnormaddrelu\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNormAddRelu() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnormaddrelu not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, 
const int grid_dim_x, const bool coop);\n  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream);\n  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format,\n                                  const cudnnDataType_t     data_type,\n                                  int n, int c, int h, int w, int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format,\n                                   const cudnnDataType_t     data_type,\n                                   int n, int c, int h, int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(\n      const std::vector<void*>&  workspace,\n      const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {\n    X_ = X;\n    dX_  = dX;\n    Y_   = Y;\n    dY_  = dY;\n    addend_   = addend;\n    dAddend_  = dAddend;\n  }\n\n  // Sets the pointers for the scale and bias (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers,\n                                 const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size()  == 2);\n    scale_  = static_cast<float*>(weight_pointers[0]);\n    bias_   = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_  = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_     = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status,\n                          const std::string& string = std::string(),\n                          bool verbose = 
VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << \" \" << cudnnGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  void checkCudaStatus(const std::string& string = std::string(),\n                       bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y)\n      LOG(INFO) << \"GPU capabilities exceed assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t  X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t  Y_tensor_desc_ = nullptr;\n\n  void*  X_ = nullptr;\n  void* dX_ = nullptr;\n  void*  Y_ = nullptr;\n  void* dY_ = nullptr;\n  void*  addend_ = nullptr;\n  void* dAddend_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_  = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_   = nullptr;\n  float* dbias_  = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_     = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_     = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float 
svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance\n\n  double exp_avg_factor_ = 0.;\n  double eps_            = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,\n                           cudnnTensorFormat_t format,\n                           cudnnDataType_t     data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float *partial_sums_ = nullptr;\n  int *partial_counts_ = nullptr;\n  int *retired_ctas_   = nullptr;\n  unsigned int *relu_bitmask_ = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams *params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams *params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  // 
increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*2*sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,\n                                dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \\\n            \"Nhwc batchnormaddrelu kernel smem too big.\"; \\\n        auto fwd_func = nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" fwd ser coop kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,\n                                dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {\n#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \\\n            \"Nhwc batchnormaddrelu kernel smem too big.\"; \\\n        auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if 
(COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_add_relu_func, \\\n                             cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \\\n                \" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd-add-relu coop serial kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_ADD_RELU_KERNEL\n  }\n\n public:\n  
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes     = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n\n  int elems_per_group = ((m_ + 31) & ~31) * 2;\n  int group_count = div_up(c_, C_ELEMENTS_PER_CTA);\n  const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);\n\n  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL*\\\n      ELEMENTS_PER_LDG*2*sizeof(float);\n  const size_t size_counts        = grid_y*grid_x*sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes, bitmask_bytes,\n          size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid 
NhwcBatchNormAddRelu::setWorkspacePointers(\n      const std::vector<void*>& workspace,\n      const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 6);\n  assert(num_workspace_bytes.size() == 6);\n\n  minibatch_mean_     = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  relu_bitmask_       = static_cast<unsigned int*>(workspace[2]);\n  retired_ctas_       = static_cast<int*>(workspace[3]);\n  partial_sums_       = static_cast<float*>(workspace[4]);\n  partial_counts_     = static_cast<int*>(workspace[5]);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dst          = static_cast<uint16_t*>(Y_);\n  params->gmem_src1         = static_cast<uint16_t*>(addend_);\n  params->gmem_bias         = bias_;\n  params->gmem_scale        = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var  = population_variance_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->rvar_inv_count    = rvar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_counts       = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps           = eps_;\n  params->outer_loops       = 0;\n  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams\n                                                        *params) const {\n  params->gmem_src   = static_cast<uint16_t*>(X_);\n  params->gmem_dst   = static_cast<uint16_t*>(Y_);\n  params->gmem_src1  = 
static_cast<uint16_t*>(addend_);\n  params->gmem_bias  = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean  = population_mean_;\n  params->gmem_var   = population_variance_;\n  params->nhw        = m_;\n  params->c          = c_;\n  params->var_eps    = eps_;\n}\n\nvoid NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dy           = static_cast<uint16_t*>(dY_);\n  params->gmem_dst          = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1         = static_cast<uint16_t*>(dAddend_);\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->gmem_dscale       = dscale_;\n  params->gmem_dbias        = dbias_;\n  params->gmem_scale        = scale_;\n  params->gmem_bias         = bias_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops       = 0;\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      //      && minibatch_mean_ != nullptr\n      //      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      && addend_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, 
PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  nhwc_batch_norm_fwd_inference\n    <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>\n  <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n  checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n}\n\ndim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;\n    int 
pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                               const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && relu_bitmask_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      && addend_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\nvoid NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                 const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != 
nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && relu_bitmask_ != nullptr\n      //      && population_mean_ != nullptr\n      //      && population_variance_ != nullptr\n      && X_ != nullptr\n      && dX_ != nullptr\n      //      && Y_ != nullptr\n      && dY_ != nullptr\n      && dAddend_ != nullptr\n      && dscale_ != nullptr\n      && dbias_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/cuda_utils.h",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#ifndef CUDA_UTILS_H\n#define CUDA_UTILS_H\n\nnamespace at {\nnamespace cuda {\n\nnamespace utils {\n\nstatic inline int MaxSharedMemoryPerMultiprocessor(int device_id) {\n    return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;\n}\n\n\n}\n}\n}\n\n\n#endif\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/interface.cpp",
    "content": "#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/ArrayRef.h>\n#include <ATen/ScalarType.h>\n#include \"ATen/Scalar.h\"\n#ifndef VERSION_GE_1_1\n#include \"ATen/Type.h\"\n#endif\n#include \"ATen/Tensor.h\"\n#include \"ATen/Storage.h\"\n#include \"ATen/Generator.h\"\n\n\nnamespace py = pybind11;\n\nint64_t get_buffer_size(\n                       const int bn_sync_steps);\n\nvoid* get_data_ptr(\n                       const at::Tensor& data);\n\nvoid* get_remote_data_ptr(\n                       const at::Tensor& handle,\n                       const int64_t offset);\n\nvoid close_remote_data(\n                       const at::Tensor& handle);\n\nat::Tensor nhwc_bn_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& 
running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu);\n\nstd::vector<at::Tensor> nhwc_bn_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float 
epsilon,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon);\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool 
coop);\n\nint nhwc_bn_fwd_occupancy();\nint nhwc_bn_bwd_occupancy();\n\nint nhwc_bn_addrelu_fwd_occupancy();\nint nhwc_bn_addrelu_bwd_occupancy();\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n\n  m.def(\"get_buffer_size\", &get_buffer_size, \"get_buffer_size\");\n  m.def(\"get_data_ptr\", &get_data_ptr, \"get_data_ptr\");\n  m.def(\"get_remote_data_ptr\", &get_remote_data_ptr, \"get_remote_data_ptr\");\n  m.def(\"close_remote_data\", &close_remote_data, \"close_remote_data\");\n\n  m.def(\"bn_fwd_nhwc\", &nhwc_bn_fwd_train, \"bn_fwd_nhwc\");\n  m.def(\"bn_fwd_eval_nhwc\", &nhwc_bn_fwd_eval, \"bn_fwd_eval_nhwc\");\n  m.def(\"bn_bwd_nhwc\", &nhwc_bn_bwd, \"bn_bwd_nhwc\");\n\n  m.def(\"bn_fwd_nhwc_occupancy\", &nhwc_bn_fwd_occupancy, \"bn_fwd_nhwc_occupancy\");\n  m.def(\"bn_bwd_nhwc_occupancy\", &nhwc_bn_bwd_occupancy, \"bn_bwd_nhwc_occupancy\");\n\n  m.def(\"bn_addrelu_fwd_nhwc\", &nhwc_bn_addrelu_fwd_train, \"bn_addrelu_fwd_nhwc\");\n  m.def(\"bn_addrelu_fwd_eval_nhwc\", &nhwc_bn_addrelu_fwd_eval, \"bn_addrelu_fwd_eval_nhwc\");\n  m.def(\"bn_addrelu_bwd_nhwc\", &nhwc_bn_addrelu_bwd, \"bn_addrelu_bwd_nhwc\");\n\n  m.def(\"bn_addrelu_fwd_nhwc_occupancy\", &nhwc_bn_addrelu_fwd_occupancy, \"bn_addrelu_fwd_nhwc_occupancy\");\n  m.def(\"bn_addrelu_bwd_nhwc_occupancy\", &nhwc_bn_addrelu_bwd_occupancy, \"bn_addrelu_bwd_nhwc_occupancy\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/ipc.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\ntemplate<>\nstruct std::hash<cudaIpcMemHandle_t> {\n  size_t operator() (const cudaIpcMemHandle_t& handle) const {\n    size_t hash = 0;\n    uint8_t* ptr = (uint8_t*)&handle;\n    assert(sizeof(uint8_t) == 1);\n    for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {\n      hash += *ptr;\n      ptr++;\n    }\n    return hash;\n  }\n};\n\ntemplate<>\nstruct std::equal_to<cudaIpcMemHandle_t> {\n  bool operator() (const cudaIpcMemHandle_t &lhs,\n                             const cudaIpcMemHandle_t &rhs) const {\n    return (std::memcmp((void*) &lhs,\n                        (void*) &rhs,\n                        sizeof(cudaIpcMemHandle_t)) == 0);\n  }\n};\n\nnamespace {\n\nnamespace gpuipc {\n//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h\n// The number of threads per pixel.\nconst int THREADS_PER_PIXEL = 16;\n// The number of elements per ldg.\nconst int ELEMENTS_PER_LDG = 4;\n// The number of reducing ops, each uses its own space : mean, var, dscale, dbias\nconst int REDUCE_OPS = 4;\n// Maximum block.y supported - limited due to buffer allocation\nconst int MAX_BLOCK_Y = 256;\nconst int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;\nconst int BYTES_PER_ELEM = 4;\n// Buffer size per sync step\nconst int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;\n};\n\nclass IpcMemHandleRegistry {\npublic:\n  void* getPtr(const 
cudaIpcMemHandle_t& handle, int64_t offset) {\n    if (registry_.count(handle) == 0) {\n      registry_.insert(std::make_pair(handle, RegistryEntry()));\n      registry_[handle].dev_ptr = ipcOpenMem(handle);\n    }\n    registry_[handle].ref_count++;\n    return (((uint8_t*)registry_[handle].dev_ptr) + offset);\n  }\n\n  void releasePtr(const cudaIpcMemHandle_t& handle) {\n    if (registry_.count(handle) == 0) {\n      // Unknown handle: nothing was opened for it, so there is nothing to release.\n      return;\n    }\n    if (--registry_[handle].ref_count == 0) {\n      ipcCloseMem(registry_[handle].dev_ptr);\n      registry_.erase(handle);\n    }\n  }\n\n  struct RegistryEntry {\n    void* dev_ptr;\n    int   ref_count;\n    RegistryEntry() : dev_ptr(NULL), ref_count(0) {}\n  };\n\nprotected:\n  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;\n\n  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {\n    void *data;\n    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);\n    cudaCheckErrors(\"ipc init\");\n    return data;\n  }\n\n  void ipcCloseMem(void* dev_ptr) {\n    cudaIpcCloseMemHandle(dev_ptr);\n    cudaCheckErrors(\"ipc close\");\n  }\n\n};\n\n}  // anonymous namespace\n\nstatic IpcMemHandleRegistry ipc_mem_registry;\n\nint64_t get_buffer_size(const int bn_sync_steps) {\n  return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;\n}\n\nvoid* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {\n  cudaIpcMemHandle_t my_handle;\n  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));\n  return ipc_mem_registry.getPtr(my_handle, offset);\n}\n\nvoid close_remote_data(const at::Tensor& handle) {\n  cudaIpcMemHandle_t my_handle;\n  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));\n  ipc_mem_registry.releasePtr(my_handle);\n}\n\nvoid* get_data_ptr(\n                   const at::Tensor& data) {\n  return data.DATA_PTR<uint8_t>();\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_kernel.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n\n#include <stdint.h>\n#include <algorithm>\n\n#define DEVICE_FUNCTION static inline __device__\n\n// CTA margin used by cooperative launch. 
Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN     3\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename T, int ELEMENTS_PER_LDG >\nstruct PackedStorage {\n    enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };\n    typedef T Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int ELEMENTS_PER_LDG >\nstruct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {\n    enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };\n    typedef int Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        uint16_t lo, hi;\n        asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(lo) : \"f\"(src[2*i+0]));\n        asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(hi) : \"f\"(src[2*i+1]));\n        asm volatile(\"mov.b32 %0, {%1, %2};\"  : \"=r\"(dst[i]) : \"h\"(lo), \"h\"(hi));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = src[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        uint16_t lo, hi;\n        asm volatile(\"mov.b32 {%0, %1}, %2;\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(src[i]));\n        asm volatile(\"cvt.f32.f16 %0, %1;\"   : \"=f\"(dst[2*i+0])   
: \"h\"(lo));\n        asm volatile(\"cvt.f32.f16 %0, %1;\"   : \"=f\"(dst[2*i+1])   : \"h\"(hi));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = src[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {\n    dst[0] = __ldg((const int*) gmem);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {\n    unsigned int tmp;\n    asm volatile (\"ld.global.cs.nc.s32 %0, [%1];\"  : \"=r\"(tmp) : \"l\" ((const uint *)gmem));\n    dst[0] = tmp;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {\n    int2 tmp = __ldg((const int2*) gmem);\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {\n    int2 tmp;\n    asm volatile (\"ld.global.cs.nc.v2.s32 {%0,%1}, [%2];\"\n        : \"=r\"(tmp.x), \"=r\"(tmp.y) : \"l\"((const int2 *)gmem));\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {\n    int tmp[N/2];\n    ldg(tmp, gmem);\n    to_float(dst, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void ldg_stream(float (&dst)[N], const 
uint16_t *gmem) {\n    int tmp[N/2];\n    ldg_stream(tmp, gmem);\n    to_float(dst, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {\n    reinterpret_cast<int*>(gmem)[0] = src[0];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {\n    unsigned int tmp = src[0];\n    asm volatile (\"st.global.cs.s32 [%0], %1;\"\n        :: \"l\"((uint *)gmem) , \"r\"(tmp));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {\n    reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {\n    asm volatile (\"st.global.cs.v2.s32 [%0], {%1,%2};\"\n        :: \"l\"((uint *)gmem) , \"r\"(src[0]), \"r\"( src[1]));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {\n    int tmp[N/2];\n    from_float(tmp, src);\n    stg(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {\n    int tmp[N/2];\n    from_float(tmp, src);\n    stg_stream(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {\n    float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));\n    dst[0] = tmp.x;\n    dst[1] = 
tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {\n    float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n    dst[2] = tmp.z;\n    dst[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {\n    float2 tmp = *(const float2*) &smem[2*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {\n    x[0] = smem[idx];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {\n    float4 tmp = *(const float4*) &smem[4*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n    x[2] = tmp.z;\n    x[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {\n    int2 tmp = *(const int2*) &smem[2*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {\n    reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {\n    reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], 
src[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {\n    reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {\n    reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {\n    smem[idx] = x[0];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {\n    reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {\n    reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void zero_array(int (&dst)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = 0;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void zero_array(float (&dst)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = 0.f;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate 
<int N>\nDEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] += y[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] *= y[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] *= scalar;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],\n                               const float (&scale)[N], const float (&m1)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] = bias[i] + scale[i] * (x[i] - m1[i]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage>\nDEVICE_FUNCTION Storage relu(Storage in) {\n    Storage zero = (Storage)0.f;\n    return (in < zero)? 
zero : in;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_activation(float (&x)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] = relu(x[i]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\ntemplate< int THREADS_PER_CTA >\nDEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,\n                                        void* params_my_data, void** params_pair_datas, int off,\n                                        const int magic,\n                                        const int sync_iters) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of threads per pixel.\n    const int THREADS_PER_PIXEL = 16;\n    // The number of elements per ldg.\n    const int ELEMENTS_PER_LDG = 4;\n    // The number of reducing ops, each uses its own space : mean, var, dscale, dbias\n    const int REDUCE_OPS = 4;\n    // Maximum block.y supported - limited due to buffer allocation\n    const int MAX_BLOCK_Y = 256;\n    const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;\n    // The warp decomposition.\n    const int warp_id = threadIdx.x / THREADS_PER_WARP;\n    const int lane_id = threadIdx.x % THREADS_PER_WARP;\n    // total size of data per sync iter\n    const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;\n\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n    }\n\n    // The warp leaders, write to SMEM.\n    if (lane_id < THREADS_PER_PIXEL) {\n        write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);\n    }\n\n    // The data is in SMEM. 
Do the final reduction.\n    __syncthreads();\n\n    // The 1st warp does all the work.\n    // For the final reduction, each half-warp sequentially reduces the final values.\n    if (warp_id == 0) {\n        read_from_smem(x, smem, threadIdx.x);\n\n        #pragma unroll\n        for (int offset = 1;\n             offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n            float y[ELEMENTS_PER_LDG];\n            // Read the mean and variance from the other pixel.\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);\n            // Compute the updated sum.\n            add(x, y);\n        }\n\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n        }\n\n        // Make sure the data was read from SMEM.\n        __syncwarp();\n\n        // Store the final values.\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n        // probably could do it earlier, before sync\n\n        for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {\n            //float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];\n            void* params_pair_data = params_pair_datas[sync_iter];\n\n            // skip the space consumed by previous sync iterations\n            const int xbuf_offset = sync_iter*data_total;\n            // data starts after flags, but have to skip previous\n            const int data_offset = xbuf_offset\n                                    + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2\n                                    + ELEMENTS_PER_LDG*threadIdx.x*2;\n\n            // after the sums for this GPU are computed, CTA0 broadcasts them to the other GPUs\n            if (blockIdx.x == 0) {\n                volatile float * write_data =\n                    &((reinterpret_cast<float*>(params_pair_data))[data_offset]);\n\n                // write the data to the memory region to be reflected to the other GPUs\n    
            asm volatile (\"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};\"\n                    :: \"l\"(write_data) , \"f\"(x[0]), \"r\"(magic), \"f\"(x[2]), \"r\"(magic));\n\n                asm volatile (\"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};\"\n                    :: \"l\"(write_data+4) , \"f\"(x[1]), \"r\"(magic), \"f\"(x[3]), \"r\"(magic));\n            }\n\n            // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU\n            volatile float * read_data =\n                &((reinterpret_cast<float*>(params_my_data))[data_offset]);\n\n            float other[4];\n            uint32_t other_flag_a, other_flag_b;\n            do {\n                asm volatile (\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                    : \"=f\"(other[0]), \"=r\"(other_flag_a), \"=f\"(other[2]), \"=r\"(other_flag_b) : \"l\"(read_data));\n            } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n            do {\n                asm volatile (\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                    : \"=f\"(other[1]), \"=r\"(other_flag_a), \"=f\"(other[3]), \"=r\"(other_flag_b) : \"l\"(read_data+4));\n            } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n            add(x, other);\n        }\n        // finally, after syncing up and accounting for partial sums from\n        // other GPUs as required, write the result\n\n\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_CTA >\nDEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of threads per pixel.\n    const int THREADS_PER_PIXEL = 8;\n    // The number of 
elements per ldg.\n    const int ELEMENTS_PER_LDG = 4;\n    // The warp decomposition.\n    const int warp_id = threadIdx.x / THREADS_PER_WARP;\n    const int lane_id = threadIdx.x % THREADS_PER_WARP;\n\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);\n    }\n\n    // The warp leaders write to SMEM.\n    if (lane_id < THREADS_PER_PIXEL) {\n        write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);\n    }\n\n    // The data is in SMEM. Do the final reduction.\n    __syncthreads();\n\n    // The 1st warp does all the work.\n    // For the final reduction, each half-warp sequentially reduces the values.\n    if (warp_id == 0) {\n        read_from_smem(x, smem, threadIdx.x);\n\n        #pragma unroll\n        for (int offset = 1;\n             offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n            float y[ELEMENTS_PER_LDG];\n            // Read the mean and variance from the other pixel.\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);\n            // Compute the updated sum.\n            add(x, y);\n        }\n\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);\n        }\n\n        // Make sure the data was read from SMEM.\n        __syncwarp();\n\n        // Store the final values.\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >\nDEVICE_FUNCTION void parallel_sums(float *smem, float 
(&x)[ELEMENTS_PER_LDG], int nhw) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of pixels computed by a single warp.\n    const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;\n\n    // The position in the warp.\n    const int nhw_in_warp = nhw % PIXELS_PER_WARP;\n    // The channel (C) position in the warp.\n    const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;\n\n    // Store the values to shared memory.\n    write_to_smem(smem, threadIdx.x, x);\n\n    // Compute the parallel sums.\n    for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {\n        // NOP.\n        __syncwarp();\n\n        // Read the running sum from the other thread.\n        float y[ELEMENTS_PER_LDG];\n        if (nhw_in_warp < offset) {\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);\n        }\n\n        // Compute the updated sum. (For threads with nhw_in_warp >= offset, y is\n        // uninitialized, but the sum those threads produce is never read back.)\n        add(x, y);\n\n        // NOP.\n        __syncwarp();\n\n        // Update the sum in SMEM.\n        if (offset > 1 && nhw_in_warp < offset) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n\n    // The warps are done. Do the final reduction at the CTA level.\n    __syncthreads();\n\n    // The warp leaders write to SMEM.\n    const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;\n    if (nhw_in_warp == 0) {\n        write_to_smem(smem, idx, x);\n    }\n\n    // The data is in SMEM. Do the final reduction.\n    __syncthreads();\n\n    // Read the first element to prepare for the final reduction.\n    if (nhw < WARPS_PER_CTA/2) {\n        read_from_smem(x, smem, threadIdx.x);\n    }\n\n    // We have the running mean and running m2. 
Let's build the mean/var of the CTA.\n    for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {\n        // NOP.\n        __syncwarp();\n\n        // Read the mean and variance from the other pixel.\n        float y[ELEMENTS_PER_LDG];\n        if (nhw < offset) {\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);\n        }\n\n        // Compute the updated sum.\n        add(x, y);\n\n        // NOP.\n        __syncwarp();\n\n        // Store the mean/var for the different pixels.\n        if (nhw < offset) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >\nstruct ParallelSums {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {\n        parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct ParallelSums<16, 4> {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {\n        parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);\n    }\n\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {\n        parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);\n    }\n};\n\ntemplate<>\nstruct ParallelSums<8, 4> {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {\n        parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline int div_up(int m, int n) {\n    return (m + n - 1) / n;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// It is expected that all threads in the CTA enter this function!\nDEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {\n\n    // Register the CTA.\n    if (threadIdx.x == 0) {\n        // Issue the membar.\n        __threadfence();\n        // Notify that the CTA is done.\n        int val_to_add = 1;\n        if (master) {\n            val_to_add = -(expected_count - 1);\n        }\n        atomicAdd(gmem_retired_ctas, val_to_add);\n    }\n\n    // Are all CTAs done?\n    if (threadIdx.x == 0) {\n        int retired_ctas = -1;\n        do {\n            __threadfence();\n            asm volatile (\"ld.global.cg.b32 %0, [%1];\"\n                : \"=r\"(retired_ctas) : \"l\"(gmem_retired_ctas));\n        } while (retired_ctas != 0);\n    }\n    __syncthreads();\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdInferenceParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n    // the final mean and variance as calculated during the training process\n    float *gmem_mean, *gmem_var;\n    // The bias/scale.\n    float *gmem_bias, *gmem_scale;\n    // The dimensions.\n    int nhw, c;\n    // epsilon\n    float var_eps;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int 
ELEMENTS_PER_LDG,\n    bool USE_RELU,\n    bool USE_ADD_RELU\n>\n__global__ __launch_bounds__(THREADS_PER_CTA)\n    void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // The start position in the NHW dimension where the CTA starts.\n    const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n    // thread's starting point in NHW\n    const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];\n    float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];\n    zero_array(mean);\n    zero_array(var);\n    zero_array(scale);\n    zero_array(bias);\n    if (is_valid_c) {\n        read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);\n        read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n        read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);\n        read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n    }\n\n    // Update the scale with the stddev and eps.\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        scale[i] *= rsqrtf(var[i] + params.var_eps);\n    }\n\n    
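// Assuming normalize() computes (x - mean) * scale + bias per element (as its\n    // argument order suggests), the fold above reduces normalization to a single\n    // FMA per channel:\n    //   y = (x - mean[i]) * scale[i] + bias[i]\n    //     = x * scale[i] + (bias[i] - mean[i] * scale[i])\n    // with scale[i] already carrying the rsqrtf(var[i] + var_eps) factor.\n\n    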
// The base pointers for reading/writing\n    uint16_t *const gmem_src = &params.gmem_src[thread_c];\n    uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n    const uint16_t *gmem_src1 = nullptr;\n    if (USE_ADD_RELU) {\n        gmem_src1 = &params.gmem_src1[thread_c];\n    }\n\n    // apply BN\n    for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {\n        float x_math[ELEMENTS_PER_LDG];\n        zero_array(x_math);\n        if (is_valid_c) {\n            ldg(x_math, &gmem_src[nhw*params.c]);\n        }\n\n        // Normalize and apply activation function\n        normalize(x_math, bias, scale, mean);\n        if (USE_ADD_RELU) {\n            float x1_math[ELEMENTS_PER_LDG];\n            ldg(x1_math, &gmem_src1[nhw*params.c]);\n            add(x_math, x1_math);\n            relu_activation(x_math);\n        } else if (USE_RELU) {\n            relu_activation(x_math);\n        }\n\n        if (is_valid_c) {\n            stg(&gmem_dst[nhw*params.c], x_math);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n    // The bias/scale.\n    float *gmem_bias, *gmem_scale;\n    // running mean/var (refer BN API from cudnn doc)\n    float *gmem_running_mean, *gmem_running_var;\n    // saved mean/var (refer BN API from cudnn doc)\n    float *gmem_saved_mean, *gmem_saved_var;\n    // ReLU bitmask\n    unsigned int *gmem_relu_bitmask;\n    // The dimensions.\n    int nhw, c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    float svar_inv_count;\n    // factor to scale sum of squared errors to get running variance. 
Should be 1/nhw or 1/(nhw-1).\n    float rvar_inv_count;\n    // The buffer to do the reduction for mean, stddev and count.\n    float *gmem_sums;\n    // The buffer to count items in the different CTAs.\n    int *gmem_counts;\n    // The counters of retired CTAs.\n    int *gmem_retired_ctas;\n    // The epsilon to apply to the computation of the variance.\n    float var_eps;\n    // outer loop count\n    int outer_loops;\n    // exponential average factor\n    float exp_avg_factor;\n    // number of CTAs along .x dimension\n    int c_blks;\n\n    void* my_data;\n    void* pair_datas[4];\n    int magic;\n    int sync_iters;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    bool USE_RELU,\n    bool USE_ADD_RELU,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / 
THREADS_PER_PIXEL;\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        // Clamp thread_c so that we load from valid locations even if we don't use the value\n        if (!is_valid_c)\n            thread_c = params.c - 4;\n\n        // Single pass numerically stable algorithm, see:\n        // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm\n        //\n        // n = 0, mean = 0.0, M2 = 0.0\n        //\n        // for x in data:\n        //     n += 1\n        //     delta = 
x - mean\n        //     mean += delta/n\n        //     delta2 = x - mean\n        //     M2 += delta*delta2\n        //\n        // if n < 2:\n        //     return float('nan')\n        // else:\n        //     return M2 / (n - 1)\n\n        // Registers to store the running count, mean and m2.\n        float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            mean[i] = 0.f;\n            m2[i] = 0.f;\n        }\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointer to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n\n        // The number of outer loops.\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;\n        // Load the batch of elements. Compute the mean/var across those elements.\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers\n            // and smem are fully utilized; the offset is evenly divisible by 32.\n            int offset = (pixels_per_iteration * OUTER_LOOPS +\n                          PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;\n            cta_nhw_regs -= offset;\n            cta_nhw_smem -= offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!!\n            cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) -\n                                 max(nhw_regs, 0), 0);\n\n            // Load the data and compute the local mean/sum and the variance.\n            if (USE_ONLINE_APPROACH) {\n                // Read the elements from memory.\n                float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                    zero_array(x_storage[i]);\n                    is_valid[i] = 0.f;\n                    if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        }\n                        is_valid[i] = 1.f;\n                    }\n                }\n\n                // Do the math.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Update the count.\n                    count += is_valid[i];\n                    // Invert the count.\n                    float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        float delta0 = x_math[j] - mean[j];\n                        mean[j] += delta0 * inv_count;\n                        float delta1 = x_math[j] - mean[j];\n                        m2[j] += delta0 * delta1 * is_valid[i];\n                    }\n                }\n            } else {\n                // Read the elements from memory.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                    zero_array(x_storage[i]);\n                    if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        }\n                        count += 1.f;\n                    }\n                }\n\n                // Sum the elements in registers.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        mean[j] += x_math[j];\n                    }\n                }\n\n                // Compute the mean.\n                float inv_count = 1.f / count;\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    mean[j] 
*= inv_count;\n                }\n\n                // Compute the variance.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Is it a valid pixel?\n                    float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;\n                    }\n                }\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                float is_pixel_valid = (((unsigned int)idx <\n                                         (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;\n\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n                ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? 
idx : 0)*params.c]);\n\n                // The offset to store in SMEM.\n                const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                // Update the mean and m2 using deltas.\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    float delta0 = x_math[j] - mean[j];\n                    mean[j] += delta0 * inv_count;\n                    float delta1 = x_math[j] - mean[j];\n                    m2[j] += delta0 * delta1 * is_pixel_valid;\n                }\n            }\n        }\n\n        // We scale the mean by the number of elements. 
It brings more stability.\n        float m1[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = mean[i] * count;\n        }\n\n        // Run the parallel sum across the CTA to get the local sum.\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, m1, thread_in_cta_nhw);\n        __syncthreads();\n\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(m1, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // Adjust the variance.\n        float inv_cta_count = 1.f / static_cast<float>(cta_count);\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            float mean_diff = m1[i]*inv_cta_count - mean[i];\n            m2[i] = m2[i] + mean_diff * mean_diff * count;\n        }\n\n        // Run the parallel sum across the CTA to get the local adjusted variance.\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, m2, thread_in_cta_nhw);\n\n        // The workspace in global memory is distributed across the different CTAs.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, m1);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);\n        }\n\n        // The memory location to store the number of pixels per CTA.\n        int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];\n        if (threadIdx.x == 0) {\n            gmem_counts[blockIdx.x] = cta_count;\n        }\n\n        // Read the bias and scale.\n        float bias[ELEMENTS_PER_LDG], 
scale[ELEMENTS_PER_LDG];\n        if (is_valid_c) {\n            read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n            read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given CTA uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the mean to compute the global mean.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = 0.f;\n        }\n\n        // Build the global mean.\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp, gmem_sums, idx);\n            add(m1, tmp);\n        }\n\n        if (params.sync_iters > 0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, m1, thread_in_cta_nhw);\n        }\n        __syncthreads();\n\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(m1, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // Normalize the mean.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = m1[i] * params.svar_inv_count;\n        }\n\n        // Reset the variance.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m2[i] = 0.f;\n        }\n\n        // For the add+relu fusion.\n        
const uint16_t *gmem_src1 = nullptr;\n        if (USE_ADD_RELU) {\n            gmem_src1 = &params.gmem_src1[thread_c];\n        }\n\n        // Build the global variance.\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.\n            float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp_mean, &gmem_sums[                           0], idx);\n            read_from_gmem(tmp_var,  &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);\n\n            // Read the number of pixels visited by a given CTA.\n            cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);\n\n            // Compute the diff to update the variance.\n            float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;\n            }\n\n            // Update the variance.\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);\n            }\n        }\n\n        if (params.sync_iters > 0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, m2, thread_in_cta_nhw);\n        }\n        __syncthreads();\n\n        read_from_smem(m2, smem, thread_in_cta_c);\n\n        // Finalize the stddev.\n        // Because saved var and running var may have different denominators, we don't do it here.\n        // scale_(m2, inv_count);\n\n        // Store the saved mean/var.\n        float svarinv[ELEMENTS_PER_LDG];\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);\n        }\n        if (is_valid_for_saving) {\n            write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);\n            write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);\n        }\n\n        // Store the running mean/var.\n        float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];\n        zero_array(rmean);\n        zero_array(rvar);\n        if (params.exp_avg_factor != 1.f && is_valid_for_saving) {\n            read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);\n            read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] +\n                params.exp_avg_factor * m1[i];\n            rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] +\n                params.exp_avg_factor * (m2[i] * params.rvar_inv_count);\n        }\n        if (is_valid_for_saving) {\n            write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);\n            write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);\n        }\n\n        // Update the scale with the stddev and eps.\n        multiply(scale, svarinv);\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +\n                                     ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n        // Store the elements in 
registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_valid = is_valid_nhw && is_valid_c;\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n\n                // Normalize and apply activation function\n                normalize(x_math, bias, scale, m1);\n                if (USE_ADD_RELU) {\n                    float x1_math[ELEMENTS_PER_LDG];\n                    ldg_stream(x1_math, &gmem_src1[(is_valid ? 
idx : 0)*params.c]);\n                    add(x_math, x1_math);\n                    unsigned int relu_mask;\n                    int lane_id = threadIdx.x & 31;\n                    #pragma unroll\n                    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                        bool rectified = x_math[i] < 0.0F;\n                        unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n                        if (lane_id == i) {\n                            // Thread 0 remembers the relu_mask from the first time through this\n                            // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.\n                            relu_mask = local_relu_mask;\n                        }\n                        if (rectified) {\n                            x_math[i] = 0.0F;\n                        }\n                    }\n                    if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n                        gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n                    }\n                } else if (USE_RELU) {\n                    relu_activation(x_math);\n                }\n\n                // Write back.\n                if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], x_math);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            #pragma unroll 2\n            
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_valid = is_valid_nhw && is_valid_c;\n\n                // Read from SMEM.\n                const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n                read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n\n                // Normalize and apply activation function\n                normalize(x_math, bias, scale, m1);\n                if (USE_ADD_RELU) {\n                    float x1_math[ELEMENTS_PER_LDG];\n                    ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);\n                    add(x_math, x1_math);\n                    unsigned int relu_mask;\n                    int lane_id = threadIdx.x & 31;\n                    #pragma unroll\n                    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                        bool rectified = x_math[i] < 0.0F;\n                        unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n                        if (lane_id == i) {\n                            relu_mask = local_relu_mask;\n                        }\n                        if (rectified) {\n                            x_math[i] = 0.0F;\n                        }\n                    }\n                    if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n                        gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n                    }\n                } else if (USE_RELU) {\n                    relu_activation(x_math);\n                }\n\n                // Write back.\n    
            if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], x_math);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormBwdParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;\n    // dscale/dbias\n    float *gmem_dscale, *gmem_dbias;\n    // The scale and bias.\n    float *gmem_scale, *gmem_bias;\n    // The mean/inv-var saved from fwd pass\n    float *gmem_saved_mean, *gmem_saved_var;\n    // ReLU bitmask\n    unsigned int *gmem_relu_bitmask;\n    // The dimensions.\n    int nhw, c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    float svar_inv_count;\n    // The buffer to do the reduction for dscale and dbias\n    float *gmem_sums;\n    // The counters of retired CTAs.\n    int *gmem_retired_ctas;\n    // outer loop count\n    int outer_loops;\n    // number of CTAs along .x dimension\n    int c_blks;\n\n    void* my_data;\n    void* pair_datas[4];\n    int magic;\n    int sync_iters;\n    float wgrad_coeff;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],\n                              const float (&mean_var_scale_bias)[N],\n                              const float (&var_scale)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n        if ((y <= 0.f) && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float 
(&y)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if ((y[j] <= 0.f) && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if (rectified[j] && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N],\n                                     const float (&x)[N],\n                                     const float (&mean_var_scale_bias)[N],\n                                     const float (&var_scale)[N]) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n        if (y <= 0.f) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if (y[j] <= 0.f) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],\n                                const float (&dy)[N], const float (&x)[N],\n                                const float (&mean)[N], float inv_count) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float delta0 = dy[j] - dbias[j];\n        dbias[j] += delta0 * inv_count;\n        delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];\n        dscale[j] += delta0 * inv_count;\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],\n                            const float (&var)[N], const float (&x)[N], const float (&mean)[N],\n                            const float (&dscale)[N], const float (&dbias)[N], float inv_count) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float tmp1 = dy[j] - (dbias[j]* inv_count);\n        float tmp2 = dscale[j] * inv_count;\n        float tmp3 = x[j] - mean[j];\n        dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    
typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        // Registers to store the mean used for entire duration\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        
zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. Compute sum across them\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized\n            int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -\n                         PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n            cta_nhw_regs += offset;\n            cta_nhw_smem += offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                    } else {\n                        ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                    }\n                    is_valid[i] = 1.f;\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                bool is_pixel_valid = (((unsigned int)idx <\n                                        (unsigned int)params.nhw) && is_valid_c);\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid) {\n                    ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 
1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // Scale the running means back up by the number of elements; accumulating as means\n        // brings more numerical stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTAs.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given CTA uses the same counter every other time through the outer loop.\n        
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        
__syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // inv-var\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // Normalize the dscale.\n        multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // scale\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n  
          for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float 
smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n\n        // Registers to store the mean/var/scale/bias used for the entire duration\n        // Register usage optimizations:\n        // 1. 
Can combine bias - (mean * var * scale) into a single register\n        // 2. Can combine var * scale into a single register\n        float varscale[ELEMENTS_PER_LDG];\n        zero_array(varscale);\n        if (is_valid_c) {\n            read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        float tmp[ELEMENTS_PER_LDG];\n        zero_array(tmp);\n        if (is_valid_c) {\n            read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(varscale, tmp);\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n        zero_array(tmp);\n        if (is_valid_c) {\n            read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);\n        }\n        float mean_var_scale_bias[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. 
Compute sum across them\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized\n            int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -\n                         PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n            cta_nhw_regs += offset;\n            cta_nhw_smem += offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                    } else {\n                        ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                    }\n                    is_valid[i] = 1.f;\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // 
Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 1.f / count : 0.f;\n\n                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                bool is_pixel_valid = (((unsigned int)idx <\n                                        (unsigned int)params.nhw) && is_valid_c);\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid) {\n                    ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += 
PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // We scale the mean by the number of elements. It brings more stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTA.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n   
         const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n 
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // Normalize the dscale.\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = 
&params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n                relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + 
i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n                    relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  
Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int 
c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n        uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. 
Compute sum across them\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized; the offset is evenly divisible by 32\n            int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x -\n                          params.nhw) & ~31;\n            cta_nhw_regs -= offset;\n            cta_nhw_smem -= offset;\n        }\n\n        const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +\n                                      ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            int lane_id = threadIdx.x & 31;\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                if (is_valid_nhw) {\n                    if (is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                            
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                            ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                        }\n                        is_valid[i] = 1.f;\n                    }\n\n                    if (lane_id < ELEMENTS_PER_LDG) {\n                        relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];\n                    }\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                // Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                bool rectified[ELEMENTS_PER_LDG];\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) &\n                                    (1U << lane_id)) != 0);\n                }\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                relu_bwd(dy_math, rectified, is_valid[i]);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n                // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version\n                from_float(dy_storage[i], dy_math);\n\n                // dZ for elementwise add\n                if (is_valid[i]) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);\n                    } else {\n                        stg(&gmem_dst1[idx*params.c], dy_storage[i]);\n                    }\n                }\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_pixel_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                unsigned int relu_mask;\n                int lane_id = threadIdx.x & 31;\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid_nhw) {\n                    if (is_valid_c) {\n                        ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                    
}\n                    if (lane_id < ELEMENTS_PER_LDG) {\n                        relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];\n                    }\n                }\n                bool rectified[ELEMENTS_PER_LDG];\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) &\n                                    (1U << lane_id)) != 0);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                relu_bwd(dy_math, rectified, is_pixel_valid);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n                from_float(dy_storage_local, dy_math);\n                // dZ for elementwise add\n                if (is_pixel_valid) {\n                    stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);\n                }\n                // only store the 'relu-dgrad'ed version!\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n            }\n        }\n\n        // We scale the mean by the number of elements. 
It brings more stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTA.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = 
threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // Normalize the dscale.\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        
multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n   
             to_float(dy_math, dy_storage[i]);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float 
x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  Needed?\n        __syncthreads();\n    }\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/layer_norm/ln_api.cpp",
    "content": "#include <torch/extension.h>\n#include \"ATen/cuda/CUDAContext.h\"\n\nvoid ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,\n                 const at::Tensor &x, const at::Tensor &gamma,\n                 const at::Tensor &beta, const float epsilon, const int rows, const int cols,\n                 cudaStream_t stream);\n\nvoid ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);\n\n\nstd::vector<at::Tensor> ln_fwd(const at::Tensor &x,      // BxSxhidden_size\n                               const at::Tensor &gamma,   // hidden_size\n                               const at::Tensor &beta,   // hidden_size\n                               const float epsilon\n) {\n\n    TORCH_CHECK(x.is_cuda());\n    TORCH_CHECK(gamma.is_cuda());\n    TORCH_CHECK(beta.is_cuda());\n\n    TORCH_CHECK(x.is_contiguous());\n    auto sizes = x.sizes();\n    TORCH_CHECK(sizes.size() == 2);\n\n    const int rows = sizes[0];\n    const int cols = sizes[1];\n\n    auto dtype = x.scalar_type();\n\n    TORCH_CHECK(gamma.dtype() == dtype);\n    TORCH_CHECK(beta.dtype() == dtype);\n\n    TORCH_CHECK(gamma.sizes() == beta.sizes());\n    TORCH_CHECK(gamma.numel() == cols);\n\n    TORCH_CHECK(epsilon >= 0.f);\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    auto y = torch::empty_like(x);\n\n    auto opts = x.options();\n\n    auto mu = torch::empty({rows}, opts.dtype(torch::kFloat32));\n    auto rsigma = torch::empty({rows}, opts.dtype(torch::kFloat32));\n\n    ln_fwd_cuda(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, stream);\n\n    return {y, mu, rsigma};\n}\n\n\n\nstd::vector<at::Tensor> ln_bwd(const at::Tensor &dw,     // BxSxhidden_size\n                               const at::Tensor &x,      // 
BxSxhidden_size\n                               const at::Tensor &mu,     // BxS, FP32!\n                               const at::Tensor &rsigma, // BxS, FP32!\n                               const at::Tensor &gamma   // hidden_size\n) {\n\n  TORCH_CHECK(x.is_cuda());\n  TORCH_CHECK(dw.is_cuda());\n  TORCH_CHECK(mu.is_cuda());\n  TORCH_CHECK(rsigma.is_cuda());\n  TORCH_CHECK(gamma.is_cuda());\n\n  TORCH_CHECK(x.is_contiguous());\n  TORCH_CHECK(dw.is_contiguous());\n\n  auto sizes = x.sizes();\n  TORCH_CHECK(sizes.size() == 2);\n  TORCH_CHECK(dw.sizes() == sizes);\n  auto rows = sizes[0];\n  auto cols = sizes[1];\n  \n  auto dtype = x.scalar_type();\n  TORCH_CHECK(dw.dtype() == dtype);\n  TORCH_CHECK(gamma.dtype() == dtype);\n  TORCH_CHECK(mu.dtype() == torch::kFloat32);\n  TORCH_CHECK(rsigma.dtype() == torch::kFloat32);\n  TORCH_CHECK(mu.sizes() == rsigma.sizes());\n  TORCH_CHECK(mu.numel() == rows);\n\n  TORCH_CHECK(gamma.numel() == cols);\n\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  auto dx = torch::empty_like(x);\n  auto dgamma = torch::empty_like(gamma);\n  auto dbeta = torch::empty_like(gamma);\n  \n  ln_bwd_cuda(dx, dgamma, dbeta, dw, x, mu, rsigma, gamma, rows, cols, stream);\n\n  return {dx, dgamma, dbeta};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"CUDA LayerNorm\"; // optional module docstring\n  m.def(\"ln_fwd\", &ln_fwd, \"Run LayerNorm forward kernel\");\n  m.def(\"ln_bwd\", &ln_bwd, \"Run LayerNorm backward kernel\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
    "content": "#include \"utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n\ntemplate<typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(void * __restrict__ dx_,\n                                                                          void * __restrict__ dg_,\n                                                                          void * __restrict__ db_,\n                                                                          const void * __restrict__ dw_,\n                                                                          const void * __restrict__ x_,\n                                                                          const void * __restrict__ mu_,\n                                                                          const void * __restrict__ rs_,\n                                                                          const void * __restrict__ g_,\n                                                                          const int rows\n                                                                        ){\n  using Vec = typename Ktraits::Vec;\n\n  enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { COLS = Ktraits::COLS };\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };\n  static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, \"\");\n  enum { NUM_ELTS = Vec::NUM_ELTS };\n  using vec_t = typename Ktraits::vec_t;\n  using base_t = typename Ktraits::base_t;\n  using compute_t = typename Ktraits::compute_t;\n  const int tidx = threadIdx.x;\n  const int bidx = blockIdx.x;\n  const int lane = tidx % THREADS_PER_WARP;\n  const int warp = tidx / THREADS_PER_WARP;\n  const int warp_m = 
warp / Ktraits::WARPS_N;\n  const int warp_n = warp % Ktraits::WARPS_N;\n  const int tid_r = warp_n * THREADS_PER_WARP + lane;\n\n  const int r = bidx * Ktraits::ROWS_PER_CTA + warp_m;\n  const int c = warp_n * THREADS_PER_WARP + lane;\n\n  const char *dw_ptr = static_cast<const char *>(dw_);\n  const char *x_ptr = static_cast<const char *>(x_);\n  const char *g_ptr = static_cast<const char *>(g_);\n  char *dx_ptr = static_cast<char *>(dx_);\n  const compute_t *mu_ptr = static_cast<const compute_t *>(mu_);\n  const compute_t *rs_ptr = static_cast<const compute_t *>(rs_);\n  static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS, \"\");\n\n  // smem for final reduction\n  //__shared__ compute_t smem_[ROWS_PER_CTA * COLS];\n  extern __shared__ compute_t smem_[];\n  // static_assert(sizeof(smem_dw_sum) == 32*1024,\"\");\n  // Using the grid stride loop we can assign multiple rows to each thread\n  // by using a number of CTAs smaller than rows / ROWS_PER_CTA\n  // We accumulate them here, one in smem, one in registers, because the smem\n  // capacity is limited compute_t * dw_sum = &smem_dw_sum[warp_m * COLS + tid_r\n  // * LDGS * NUM_ELTS];\n  compute_t dwy_sum[LDGS * NUM_ELTS];\n  compute_t dw_sum[LDGS * NUM_ELTS];\n\n  memset(dwy_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);\n  memset(dw_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);\n  // Debug 8 rows, 4B, 1024 cols\n\n  __shared__ compute_t smem_mdy[ROWS_PER_CTA * WARPS_N];\n  __shared__ compute_t smem_mdyy[ROWS_PER_CTA * WARPS_N];\n  compute_t *mdy_shared = &smem_mdy[warp_m * WARPS_N];\n  compute_t *mdyy_shared = &smem_mdyy[warp_m * WARPS_N];\n\n  constexpr float rn = 1.f / float(COLS);\n  Vec gamma[LDGS];\n  int col = c;\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n    gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);\n    col += Ktraits::THREADS_PER_ROW;\n  }\n  // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the\n  // last blocks with syncthreads!\n  // grid stride over 
rows\n  #pragma unroll 1\n  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {\n    const compute_t mu_r = mu_ptr[row];\n    const compute_t rs_r = rs_ptr[row];\n    Vec dw[LDGS], x[LDGS], dx[LDGS];\n    int col = c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dw[it].load_from(dw_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += THREADS_PER_ROW;\n    }\n    // local reductions\n    compute_t dy[LDGS * NUM_ELTS];\n    compute_t y[LDGS * NUM_ELTS];\n\n    compute_t mdy_local = 0.f;\n    compute_t mdyy_local = 0.f;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < Vec::NUM_ELTS; jt++) {\n        compute_t x_tmp = x[it].data.elt[jt];\n        compute_t y_tmp = rs_r * (x_tmp - mu_r);\n        compute_t dy_tmp = gamma[it].data.elt[jt] * dw[it].data.elt[jt];\n        compute_t dw_tmp = dw[it].data.elt[jt];\n\n        mdy_local += dy_tmp;\n        mdyy_local += dy_tmp * y_tmp;\n\n        dy[it * NUM_ELTS + jt] = dy_tmp;\n        y[it * NUM_ELTS + jt] = y_tmp;\n\n        dwy_sum[it * NUM_ELTS + jt] += dw_tmp * y_tmp;\n        dw_sum[it * NUM_ELTS + jt] += dw_tmp;\n      }\n    }\n\n    // reduction across row for mdy, mdyy\n    if (WARPS_N == 1) { // no need to go through smem!\n#pragma unroll\n      for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n        mdy_local += __shfl_xor_sync(uint32_t(-1), mdy_local, it);\n        mdyy_local += __shfl_xor_sync(uint32_t(-1), mdyy_local, it);\n      }\n\n      mdy_local *= rn;\n      mdyy_local *= rn;\n\n    } else {\n\n#pragma unroll\n      for (int it = 16; it > 0; it /= 2) {\n        mdy_local += __shfl_down_sync(uint32_t(-1), mdy_local, it);\n        mdyy_local += __shfl_down_sync(uint32_t(-1), mdyy_local, it);\n      } // lane 0 holds the result!\n\n      if (lane == 0) {\n        mdy_shared[warp_n] = mdy_local;\n        mdyy_shared[warp_n] = 
mdyy_local;\n      }\n\n      __syncthreads();\n      if (warp_n == 0 && lane == 0) {\n        mdy_local = 0.f;\n        mdyy_local = 0.f;\n        for (int it = 0; it < WARPS_N; it++) {\n          mdy_local += mdy_shared[it];\n          mdyy_local += mdyy_shared[it];\n        }\n        mdy_shared[0] = mdy_local;\n        mdyy_shared[0] = mdyy_local;\n      }\n      __syncthreads();\n\n      mdy_local = mdy_shared[0] * rn;\n      mdyy_local = mdyy_shared[0] * rn;\n    }\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t dy_tmp = dy[it * NUM_ELTS + jt];\n        compute_t y_tmp = y[it * NUM_ELTS + jt];\n        compute_t dx_tmp =\n            compute_t(rs_r) * (dy_tmp - mdyy_local * y_tmp - mdy_local);\n        dx[it].data.elt[jt] = dx_tmp;\n      }\n    }\n\n    col = c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dx[it].store_to(dx_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += Ktraits::THREADS_PER_ROW;\n    }\n\n  } // end: grid stride loop\n\n  // Finalize reduction of part dgamma and dbeta for this CTA\n  // by reducing over the rows held across the WARPS_M warps\n\n  enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };\n  static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, \"\");\n\n  compute_t *smem_write;\n\n  smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n    for (int jt = 0; jt < NUM_ELTS; jt++) {\n      smem_write[jt] = dw_sum[it * NUM_ELTS + jt];\n    }\n    smem_write += THREADS_PER_ROW * NUM_ELTS;\n  }\n  __syncthreads();\n  compute_t cta_dw_sum[NUM_RES];\n  memset(cta_dw_sum, 0, sizeof(compute_t) * NUM_RES);\n  for (int it = 0; it < ROWS_PER_CTA; it++) {\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      cta_dw_sum[jt] += smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n    }\n  }\n  __syncthreads();\n\n  smem_write = &smem_[warp_m * 
COLS + tid_r * NUM_ELTS];\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n    for (int jt = 0; jt < NUM_ELTS; jt++) {\n      smem_write[jt] = dwy_sum[it * NUM_ELTS + jt];\n    }\n    smem_write += THREADS_PER_ROW * NUM_ELTS;\n  }\n  __syncthreads();\n  compute_t cta_dwy_sum[NUM_RES];\n  memset(cta_dwy_sum, 0, sizeof(compute_t) * NUM_RES);\n  for (int it = 0; it < ROWS_PER_CTA; it++) {\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      cta_dwy_sum[jt] +=\n          smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n    }\n  }\n\n  compute_t *dgamma_part = static_cast<compute_t *>(dg_) + bidx * COLS + tidx;\n  for (int jt = 0; jt < NUM_RES; jt++) {\n    *dgamma_part = cta_dwy_sum[jt];\n    dgamma_part += Ktraits::THREADS_PER_CTA;\n  }\n\n  compute_t *dbeta_part = static_cast<compute_t *>(db_) + bidx * COLS + tidx;\n  for (int jt = 0; jt < NUM_RES; jt++) {\n    *dbeta_part = cta_dw_sum[jt];\n    dbeta_part += Ktraits::THREADS_PER_CTA;\n  }\n}\n\ntemplate<typename Ktraits, typename out_t>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(void * __restrict__ dg_,\n                                                                                   void * __restrict__ db_,\n                                                                                   const void * __restrict__ dg_part_,\n                                                                                   const void * __restrict__ db_part_,\n                                                                                   const int rows\n                                                                                  ){\n    using Vec = typename Ktraits::Vec;\n    enum { NUM_ELTS = Vec::NUM_ELTS };\n\n\n    using vec_t = typename Ktraits::vec_t;\n    using base_t = typename Ktraits::base_t;\n    using compute_t = typename Ktraits::compute_t;\n\n    enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n    enum { ROWS_PER_CTA = 
Ktraits::ROWS_PER_CTA };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { COLS = Ktraits::COLS };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum {VEC_COLS = BYTES_PER_ROW / BYTES_PER_LDG};\n    //dbg\n    static_assert(VEC_COLS == COLS / NUM_ELTS, \"\"); \n    //static_assert(VEC_COLS == 1024,\"\");\n    const int tidx = threadIdx.x;\n    const int bidx = blockIdx.x;\n    const int lane = tidx % THREADS_PER_WARP;\n    const int warp = tidx / THREADS_PER_WARP;\n    const int warp_m = warp / Ktraits::WARPS_N;\n    const int warp_n = warp % Ktraits::WARPS_N;\n    const int tid_c = warp_n * THREADS_PER_WARP + lane;\n    const int c =bidx * THREADS_PER_ROW + tid_c;\n    const int r = warp_m;\n    \n    __shared__ compute_t smem_[(WARPS_M - 1) * THREADS_PER_ROW * NUM_ELTS];\n    \n    //Will probably run this with WARPS_N = 1 and grid = 1024 / (32*4) = 8, or NUM_ELTS=1 and grid = 32 \n    // and WARPS_M = 4 (or 1??)\n    for(int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW){\n      const char* dg_part_ptr = static_cast<const char*>(dg_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;\n      const char* db_part_ptr = static_cast<const char*>(db_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;\n\n      compute_t dg_sum[NUM_ELTS];\n      compute_t db_sum[NUM_ELTS];\n      memset(dg_sum, 0, sizeof(compute_t) * NUM_ELTS);\n      memset(db_sum, 0, sizeof(compute_t) * NUM_ELTS);\n      #pragma unroll\n      for(int row = r; row < rows;row += ROWS_PER_CTA){\n        Vec dg;\n        Vec db;\n        dg.load_from(dg_part_ptr);\n        db.load_from(db_part_ptr);\n        dg_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;\n        db_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;\n\n        #pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          dg_sum[jt] += dg.data.elt[jt];\n          db_sum[jt] += db.data.elt[jt];\n        }\n      
}\n\n      // Finalize the reduction across rows of the CTA\n      compute_t * smem_write;\n      smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;\n\n      if (warp_m > 0) {\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          *smem_write = dg_sum[jt];\n          smem_write+=THREADS_PER_ROW;\n        }\n      }\n      __syncthreads();\n      compute_t *smem_read ;\n      smem_read = smem_ + tid_c ;\n      if (warp_m == 0) {\n#pragma unroll\n        for (int it = 0; it < WARPS_M - 1; it++) {\n#pragma unroll\n          for (int jt = 0; jt < NUM_ELTS; jt++) {\n            dg_sum[jt] += *smem_read;\n            smem_read += THREADS_PER_ROW;\n          }\n        }\n      }\n\n      __syncthreads();\n\n      smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;\n\n      if (warp_m > 0) {\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          *smem_write = db_sum[jt];\n          smem_write+=THREADS_PER_ROW;\n        }\n      }\n      __syncthreads();\n      smem_read = smem_ + tid_c;\n      if (warp_m == 0) {\n#pragma unroll\n        for (int it = 0; it < WARPS_M - 1; it++) {\n#pragma unroll\n          for (int jt = 0; jt < NUM_ELTS; jt++) {\n            db_sum[jt] += *smem_read;\n            smem_read += THREADS_PER_ROW;\n          }\n        }\n\n        using vout_t = typename Vec_type<sizeof(out_t) * NUM_ELTS>::Type;\n        union {\n          vout_t raw;\n          out_t elt[NUM_ELTS];\n        } dg_out, db_out;\n\n        // out_t dg_out[NUM_ELTS], db_out[NUM_ELTS];\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          dg_out.elt[jt] = dg_sum[jt];\n          db_out.elt[jt] = db_sum[jt];\n        }\n        vout_t *dg_ptr = reinterpret_cast<vout_t *>(dg_) + col ;\n        vout_t *db_ptr = reinterpret_cast<vout_t *>(db_) + col ;\n        *dg_ptr = dg_out.raw;\n        *db_ptr = db_out.raw;\n      }\n    }\n}\n\ntemplate<typename scalar_t>\nvoid launch(at::Tensor 
&dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 at::Tensor &dgamma_part, at::Tensor &dbeta_part,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, const int gridx, cudaStream_t stream){\n\n  if (cols == 1024) {\n    using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;\n\n    if (Ktraits::SMEM_BYTES >= 48 * 1024) {\n      AT_CUDA_CHECK(cudaFuncSetAttribute(\n          ln_bwd_kernel<Ktraits>, cudaFuncAttributeMaxDynamicSharedMemorySize,\n          Ktraits::SMEM_BYTES));\n    }\n\n    ln_bwd_kernel<Ktraits>\n        <<<gridx, Ktraits::THREADS_PER_CTA, Ktraits::SMEM_BYTES, stream>>>(\n            dx.data_ptr(), dgamma_part.data_ptr(), dbeta_part.data_ptr(),\n            dw.data_ptr(), x.data_ptr(), mu.data_ptr(), rsigma.data_ptr(),\n            gamma.data_ptr(), rows);\n\n    using Ktraits2 = Kernel_traits<float, 1024, 16, 1, 4>;\n\n    constexpr int grid2 =\n        DIVUP(1024, Ktraits2::THREADS_PER_ROW * Ktraits2::Vec::NUM_ELTS);\n\n    ln_bwd_finalize_kernel<Ktraits2, scalar_t>\n        <<<grid2, Ktraits2::THREADS_PER_CTA, 0, stream>>>(\n            dgamma.data_ptr(), dbeta.data_ptr(), dgamma_part.data_ptr(),\n            dbeta_part.data_ptr(), gridx);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n  AT_CUDA_CHECK(cudaPeekAtLastError());\n}\n\nvoid ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream) {\n\n\n  const auto dtype = x.scalar_type();\n\n\n  const auto props = at::cuda::getCurrentDeviceProperties();\n  const int smCount = props->multiProcessorCount;\n  // Launch 2 CTAs per SM \n  const int grid = 2 * smCount;\n\n  //request workspace for 
two-step reduction. We always reduce in FP32.\n  auto opts = x.options();\n  auto dbeta_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));\n  auto dgamma_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));\n\n  if (dtype == torch::kFloat16) {\n    launch<half>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);\n  } else if (dtype == torch::kFloat32) {\n    launch<float>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n}"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
    "content": "#include \"utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(\n    void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_,\n    const void *__restrict__ x_, const void *__restrict__ gamma_,\n    const void *__restrict__ beta_, const float epsilon, int rows) {\n\n  using Vec = typename Ktraits::Vec;\n\n  using base_t = typename Ktraits::base_t;\n  using compute_t = typename Ktraits::compute_t;\n  enum { NUM_ELTS = Vec::NUM_ELTS };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n  static_assert(BYTES_PER_LDG == 16, \"\");\n\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };\n  static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, \"\");\n\n  const int tidx = threadIdx.x;\n  const int bidx = blockIdx.x;\n  const int lane = tidx % THREADS_PER_WARP;\n  const int warp = tidx / THREADS_PER_WARP;\n  const int warp_n = warp % WARPS_N;\n  const int warp_m = warp / WARPS_N;\n\n  const int c = warp_n * THREADS_PER_WARP + lane;\n  const int r = bidx * ROWS_PER_CTA + warp_m;\n\n  const char *x_ptr = static_cast<const char *>(x_);\n\n  const char *g_ptr = static_cast<const char *>(gamma_);\n  const char *b_ptr = static_cast<const char *>(beta_);\n\n  char *y_ptr = static_cast<char *>(y_);\n  compute_t *mu_ptr = static_cast<compute_t *>(mu_);\n  compute_t *rs_ptr = static_cast<compute_t *>(rsigma_);\n\n  Vec gamma[LDGS];\n  Vec beta[LDGS];\n#pragma unroll\n  for (int it = 0, col = c; it < LDGS; it++) {\n    gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);\n    beta[it].load_from(b_ptr + col * BYTES_PER_LDG);\n    col += 
THREADS_PER_ROW;\n  }\n\n  constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);\n  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {\n    Vec x[LDGS];\n#pragma unroll\n    for (int it = 0, col = c; it < LDGS; it++) {\n      x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += THREADS_PER_ROW;\n    }\n    compute_t xf[LDGS * NUM_ELTS];\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);\n      }\n    }\n\n    compute_t mu_local = 0.f;\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        mu_local += xf[it * NUM_ELTS + jt];\n      }\n    }\n\n#pragma unroll\n    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n      mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);\n    }\n    mu_local *= rn;\n    if(lane == 0){\n    mu_ptr[row] = mu_local;\n    }\n    compute_t var_local = 0.f;\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;\n        var_local += diff * diff;\n      }\n    }\n\n#pragma unroll\n    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n      var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);\n    }\n    compute_t rsigma = rsqrtf(var_local * rn + epsilon);\n    if(lane == 0){\n    rs_ptr[row] = rsigma;\n    }\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));\n        x[it].data.elt[jt] = gamma[it].data.elt[jt] *  tmp + beta[it].data.elt[jt];\n      }\n    }\n\n#pragma unroll\n    for (int it = 0, col = c; it < LDGS; it++) {\n      x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += 
THREADS_PER_ROW;\n    }\n  }\n}\ntemplate<typename scalar_t>\nvoid launch(\n    at::Tensor & y, // BxSxhidden_size\n    at::Tensor & mu,\n    at::Tensor & rsigma,\n    const at::Tensor & x, // BxSxhidden_size\n    const at::Tensor & gamma,\n    const at::Tensor & beta,\n    const float epsilon,\n    const int rows,\n    const int cols,\n    const int max_gridx,\n    cudaStream_t stream\n){\n\n  if (cols == 1024) {\n    using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;\n    const int grid =\n        std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);\n\n    ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(\n        y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),\n        gamma.data_ptr(), beta.data_ptr(), epsilon, rows);\n\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n  AT_CUDA_CHECK(cudaPeekAtLastError());\n}\n\nvoid ln_fwd_cuda(\n    at::Tensor & y, // BxSxhidden_size\n    at::Tensor & mu,\n    at::Tensor & rsigma,\n    const at::Tensor & x, // BxSxhidden_size\n    const at::Tensor & gamma,\n    const at::Tensor & beta,\n    const float epsilon,\n    const int rows, const int cols,\n    cudaStream_t stream\n){\n\n  const auto dtype = x.scalar_type();\n  const auto props = at::cuda::getCurrentDeviceProperties();\n  const int max_gridx = props->maxGridSize[0];\n\n  //TODO \n  // - Using dispatch macro costs 1% perf wtf?!?!\n  // - Tune FP32 warps\n  // - Add more sizes\n  if (dtype == torch::kFloat16) {\n    launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);\n  } else if (dtype == torch::kFloat32) {\n    launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n}"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/layer_norm/ln_kernel_traits.h",
    "content": "#pragma once\n\nconstexpr uint32_t THREADS_PER_WARP = 32;\n\ntemplate <typename dtype, int COLS_, int WARPS_M_, int WARPS_N_,\n          int BYTES_PER_LDG_ = 16>\nstruct Kernel_traits {\n  enum { WARPS_M = WARPS_M_ };\n  enum { WARPS_N = WARPS_N_ };\n  enum { COLS = COLS_ };\n  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n\n  using Vec = Vec<dtype, BYTES_PER_LDG>;\n\n  using vec_t = typename Vec::vec_t;\n  using base_t = typename Vec::base_t;\n  using packed_t = typename Vec::packed_t;\n  using compute_t = typename Vec::compute_t;\n  using packed_compute_t = typename Vec::packed_compute_t;\n\n  enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };\n  enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };\n  enum { ROWS_PER_CTA = WARPS_M };\n\n  enum { BYTES_PER_ROW = COLS * sizeof(base_t) };\n  enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };\n  enum {SMEM_BYTES = ROWS_PER_CTA * COLS * sizeof(compute_t)};\n};\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/layer_norm/utils.cuh",
    "content": "#pragma once\n\n#include \"torch/extension.h\"\n#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK\n\n#define DIVUP(x, y) (((x) + ((y)-1)) / (y))\n\n#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                               \\\n  [&] {                                                                        \\\n    const auto &the_type = TYPE;                                               \\\n    /* don't use TYPE again in case it is an expensive or side-effect op */    \\\n    at::ScalarType _st = ::detail::scalar_type(the_type);                      \\\n    switch (_st) {                                                             \\\n      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)          \\\n      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)        \\\n    default:                                                                   \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(_st), \"'\");           \\\n    }                                                                          \\\n  }()\n\ntemplate <int Bytes> struct Vec_type {};\n\ntemplate <> struct Vec_type<16> {\n  using Type = uint4;\n  static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }\n};\ntemplate <> struct Vec_type<8> {\n  using Type = uint2;\n  static __device__ inline Type zero() { return make_uint2(0, 0); }\n};\n\ntemplate <> struct Vec_type<4> {\n  using Type = uint32_t;\n  static __device__ inline Type zero() { return 0; }\n};\n\ntemplate <> struct Vec_type<2> {\n  using Type = uint16_t;\n  static __device__ inline Type zero() { return 0; }\n};\n\ntemplate <typename T> struct TypeInfo {\n  using base_t = T;\n  using packed_t = T;\n  using compute_t = float;\n  using packed_compute_t = float;\n};\n\ntemplate <> struct TypeInfo<half> {\n  using base_t = half;\n  using packed_t = half2;\n  using compute_t = float;\n  using packed_compute_t = float2;\n};\n\ntemplate <typename dtype, int Bytes> 
struct Vec {\n\n  using base_t = typename TypeInfo<dtype>::base_t;\n  using packed_t = typename TypeInfo<dtype>::packed_t;\n  using compute_t = typename TypeInfo<dtype>::compute_t;\n  using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;\n\n  static_assert(Bytes % sizeof(base_t) == 0, \"\");\n  static_assert(Bytes % sizeof(packed_t) == 0, \"\");\n  enum { BYTES_PER_THREAD = Bytes };\n  enum { NUM_ELTS = Bytes / sizeof(base_t) };\n  enum { NUM_PACKED = Bytes / sizeof(packed_t) };\n  using vec_t = typename Vec_type<Bytes>::Type;\n  using store_t = union {\n    vec_t raw;\n    base_t elt[NUM_ELTS];\n    packed_t packed[NUM_PACKED];\n  };\n  store_t data;\n\n  __device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }\n\n  __device__ inline void load_from(const char *ptr) {\n    data.raw = *reinterpret_cast<const vec_t *>(ptr);\n  }\n\n  __device__ inline void load_or_zero(const char *ptr, const bool is_valid) {\n    data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)\n                        : Vec_type<Bytes>::zero();\n  }\n\n  __device__ inline void store_to(char *ptr) const {\n    *reinterpret_cast<vec_t *>(ptr) = data.raw;\n  }\n\n  __device__ inline void store_valid(char *ptr, const bool is_valid) const {\n    if (is_valid)\n      *reinterpret_cast<vec_t *>(ptr) = data.raw;\n  }\n};\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp",
    "content": "#include <torch/extension.h>\n#include <cuda_fp16.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const half*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t       bool \t\t\t\tuse_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(input.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  
\tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  }\n\n  return fwd_cuda(\n                                 is_training,\n                                 heads, \n                                 input, \n                                 use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\ntorch::Tensor bwd(\n\t\t               bool use_mask,\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n//  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n\t\t                 heads,\n                                 output_grads,\n                                 softmax_results, \n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace additive_mask_softmax_dropout\n} // end namespace fused_softmax\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, \"Self Multihead Attention masked softmax dropout -- Forward.\");\n  
m.def(\"backward\", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, \"Self Multihead Attention masked softmax dropout -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"softmax.h\"\n#include \"dropout.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n\t\t\t       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const half*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = input.size(0);\n  const int   sequences      = attn_batches / heads;\n  const int   q_seq_len      = input.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n\n  // Softmax 
Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n      softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {\n           dropout_results,  \n           dropout_mask, \n           softmax_results\n         };\n}\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results, \n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = output_grads.size(0);\n  const int   q_seq_len      
= output_grads.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n//  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n  // backward pass is completely in-place\n  return output_grads;\n}\n} // end namespace additive_mask_softmax_dropout\n} // end namespace fused_softmax\n} // end namespace multihead_attn\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/dropout.h",
    "content": "#include <ATen/ATen.h>\n\n#ifdef OLD_GENERATOR\n#include <ATen/CUDAGenerator.h>\n#else\n#include <ATen/CUDAGeneratorImpl.h>\n#endif\n\n#include <ATen/cuda/CUDAContext.h>\n#include <curand_kernel.h>\n\n#include <THC/THCGeneral.h>\n\nconst int UNROLL = 4;\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_fused_dropout_kernel(scalar_t const                *inputs,\n                                          scalar_t                      *outputs,\n                                          uint8_t                       *mask,\n                                          IndexType                      totalElements, \n\t\t                                  accscalar_t                    p, \n\t\t                                  std::pair<uint64_t, uint64_t>  seeds\n                                         ) \n{\n  accscalar_t pinv = accscalar_t(1)/p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(\n      seeds.first,\n      idx,\n      seeds.second,\n      &state);\n\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       float4 rand = curand_uniform4(&state);\n       scalar_t src[UNROLL];\n       rand.x = rand.x <= p;\n       rand.y = rand.y <= p;\n       rand.z = rand.z <= p;\n       rand.w = rand.w <= p;\n\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii] = inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           outputs[li] = 
src[ii]*(&rand.x)[ii]*pinv;\n               mask[li]    = (uint8_t)(&rand.x)[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_dropout_add_kernel(scalar_t const                *inputs,\n                                        scalar_t const                *add_inputs,\n                                        scalar_t                      *outputs,\n                                        uint8_t                       *mask,\n                                        IndexType                      totalElements, \n\t\t                                accscalar_t                    p, \n\t\t                                std::pair<uint64_t, uint64_t>  seeds\n                                       ) \n{\n  accscalar_t pinv = accscalar_t(1)/p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(\n      seeds.first,\n      idx,\n      seeds.second,\n      &state);\n\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       float4 rand = curand_uniform4(&state);\n       scalar_t src[UNROLL];\n       scalar_t add_src[UNROLL];\n       rand.x = rand.x <= p;\n       rand.y = rand.y <= p;\n       rand.z = rand.z <= p;\n       rand.w = rand.w <= p;\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii]     = inputs[li];\n               add_src[ii] = add_inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           
accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;\n\t           outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);\n               mask[li]    = (uint8_t)(&rand.x)[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_add_kernel(          scalar_t const                *inputs,\n                                        scalar_t const                *add_inputs,\n                                        scalar_t                      *outputs,\n                                        IndexType                      totalElements\n                             ) \n{\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       scalar_t src[UNROLL];\n       scalar_t add_src[UNROLL];\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii]     = inputs[li];\n               add_src[ii] = add_inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           outputs[li] = src[ii] + add_src[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate<typename scalar_t, \n\t\t typename accscalar_t, \n\t\t typename IndexType\n\t\t>\n__global__ void apex_masked_scale_kernel(scalar_t const *inputs, \n                                         scalar_t       *outputs, \n                                         uint8_t const  *mask, \n                                         IndexType       
totalElements,\n                                         accscalar_t     scale\n                                        )\n{\n  IndexType idx          = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) \n  {\n       scalar_t src[UNROLL];\n       scalar_t msk[UNROLL];\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii] = static_cast<scalar_t>(inputs[li]);\n               msk[ii] = static_cast<scalar_t>(mask[li]);\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);\n           }\n       }\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_fused_dropout_cuda(scalar_t const *inputs,\n                           scalar_t       *outputs,\n                           uint8_t        *mask,\n                           IndexType       totalElements, \n\t\t                   accscalar_t     p)\n{\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n  \n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  //number of times random will be generated per thread, to offset philox counter in thc random state\n  
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n#ifdef OLD_GENERATOR\n    std::lock_guard<std::mutex> lock(gen->mutex_);\n    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);\n#else\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n#endif\n  }\n\n  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_dropout_add_cuda(scalar_t const *inputs,\n                           scalar_t const *add_inputs,\n                           scalar_t       *outputs,\n                           uint8_t        *mask,\n                           IndexType       totalElements, \n\t\t                   accscalar_t     p)\n{\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n  \n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  //number of times random will be generated per thread, to offset philox counter in thc random state\n  int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n#ifdef OLD_GENERATOR\n    std::lock_guard<std::mutex> lock(gen->mutex_);\n    
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);\n#else\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n#endif\n  }\n\n  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_add_cuda(scalar_t const *inputs,\n                   scalar_t const *add_inputs,\n                   scalar_t       *outputs,\n                   IndexType       totalElements\n\t\t          )\n{\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate<typename scalar_t, \n         typename accscalar_t, \n         typename IndexType\n        >\nvoid apex_masked_scale_cuda(scalar_t const *inputs, \n                          scalar_t       *outputs, \n                          uint8_t const  *mask, \n                          IndexType       totalElements,\n                          accscalar_t     scale\n                         )\n{\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = 
std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);\n  THCudaCheck(cudaGetLastError());\n}\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) 
AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs_q.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()        == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights_q.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()   == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs_q.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  \n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 
2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs_q, \n                                 inputs_kv, \n                                 input_weights_q, \n                                 input_weights_kv, \n                                 output_weights, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()      == 3, \"expected 3D tensor\");\n  
AT_ASSERTM(softmax_results.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_q_results.dim()  == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_kv_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_q.dim()             == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights_q.dim()      == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()         == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_q_results.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_q.type().scalarType()             == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()         == at::ScalarType::Byte, \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                  
               heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_q_results, \n                                 input_lin_kv_results, \n                                 inputs_q, \n                                 inputs_kv, \n                                 input_weights_q,\n                                 input_weights_kv,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::encdec::cublas_gemmex::fwd, \"Encdec Multihead Attention Forward.\");\n  m.def(\"backward\", &multihead_attn::encdec::cublas_gemmex::bwd, \"Encdec Multihead Attention Backward.\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const 
int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs_q.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len},      mask_options);\n  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},       act_options);\n  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Q Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_q_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_q_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_kv_dim, \n                             batches_kv, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             
k_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_kv_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim_q, \n                             batch_stride_q, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n          
                   attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(softmax_results.data_ptr()), \n                               static_cast<at::Half*>(dropout_results.data_ptr()), \n                               static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                               dropout_elems,\n                               (1.0f - dropout_prob));\n  }\n \n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_lin_q_results, \n           input_lin_kv_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n         
                      torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_q_grads          = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads         = torch::empty_like(inputs_kv);\n  torch::Tensor input_weight_q_grads   = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads  = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n  \n  auto q_lin_grads_ptr   = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n              
               static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             
static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim_kv, \n                             batch_stride_kv,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  
softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                            
 k_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Input Linear Q Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_q, \n                             output_lin_q_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_q_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Q Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_q_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_q_dim,\n                             static_cast<const void*>(&beta),\n                           
  static_cast<void*>(input_weight_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_kv, \n                             output_lin_kv_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_kv_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_kv_dim,\n                             batches_kv, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F,\n                             
output_lin_kv_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_q_grads, \n           input_kv_grads, \n           input_weight_q_grads, \n           input_weight_kv_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n  
                             torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs_q.dim()               == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()              == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim()  == 1, \"expected 1D tensor\");\n  
AT_ASSERTM(lyr_nrm_beta_weights.dim()   == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights_q.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()         == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs_q.type().scalarType()              == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n  \n  if (use_mask) {\n    AT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n    AT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs_q, \n                                 inputs_kv,\n\t\t\t\t\t\t\t\t lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t\t lyr_nrm_beta_weights,\n                                 input_weights_q, \n                                 input_weights_kv, \n                                 output_weights, \n                                 use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_q_results.dim()   == 3, \"expected 3D tensor\");\n  
AT_ASSERTM(input_lin_kv_results.dim()  == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, \"expected 1D tensor\");\n  AT_ASSERTM(inputs_q.dim()              == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()             == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights_q.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()      == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_add_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_q_results.type().scalarType()   == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_kv_results.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(inputs_q.type().scalarType()       
       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_q_results, \n                                 input_lin_kv_results, \n                                 lyr_nrm_results,\n                                 lyr_nrm_mean,\n                                 lyr_nrm_invvar,\n                                 inputs_q, \n                                 inputs_kv, \n\t\t\t\t\t\t\t\t lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t\t lyr_nrm_beta_weights,\n                                 input_weights_q,\n                                 input_weights_kv,\n                                 output_weights,\n                                 dropout_mask,\n                                 dropout_add_mask,\n                                 dropout_prob\n                                );\n}\n\n} // 
end namespace cublas_gemmex\n} // end namespace encdec_norm_add \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, \"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.\");\n  m.def(\"backward\", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, \"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   total_tokens_q    = batches_q * embed_dim;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = 
heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options                   = inputs_q.options().requires_grad(false);\n  auto lyr_nrm_options               = act_options.dtype(torch::kFloat32);\n  auto mask_options                  = act_options.dtype(torch::kUInt8);\n  \n  torch::Tensor lyr_nrm_mean         = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar       = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results      = torch::empty_like(inputs_q, act_options);\n\n  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len},      mask_options);\n  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},       act_options);\n  
torch::Tensor output_lin_results   = torch::empty_like(inputs_q, act_options);\n  torch::Tensor dropout_add_mask     = torch::empty_like(inputs_q, mask_options);\n  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half,float>(\n                             static_cast<at::Half*>(lyr_nrm_results.data_ptr()),\n                             static_cast<float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<float*>(lyr_nrm_invvar.data_ptr()),\n                             static_cast<const at::Half*>(inputs_q.data_ptr()),\n                             static_cast<int>(batches_q), // n1\n                             static_cast<int>(embed_dim), // n2\n                             1.0e-5,\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Q Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_q_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const 
void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             //static_cast<const void*>(inputs_q.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_q_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_kv_dim, \n                             batches_kv, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             k_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_kv_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n        
                     scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim_q, \n                             batch_stride_q, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             
attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n  \n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(softmax_results.data_ptr()), \n                             static_cast<at::Half*>(dropout_results.data_ptr()), \n                             static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()), \n                             //static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // End-of-block Dropout-Add \n  if (is_training) {\n    apex_dropout_add_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                             static_cast<at::Half const*>(inputs_q.data_ptr()), \n                             static_cast<at::Half*>(outputs.data_ptr()), \n                         
    static_cast<uint8_t*>(dropout_add_mask.data_ptr()),\n                             total_tokens_q,\n                             (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                             static_cast<at::Half const*>(inputs_q.data_ptr()), \n                             static_cast<at::Half*>(outputs.data_ptr()), \n                             total_tokens_q);\n  }\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n\t\t   lyr_nrm_results,\n\t\t   lyr_nrm_mean,\n\t\t   lyr_nrm_invvar, \n           input_lin_q_results, \n           input_lin_kv_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n\t\t   dropout_add_mask, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                            
   torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   total_tokens_q    = batches_q * embed_dim;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_q_grads          = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads         = torch::empty_like(inputs_kv);\n  torch::Tensor lyr_nrm_gamma_grads    = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads     = 
torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_q_grads   = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads  = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor dropout_add_grads         = torch::empty_like(output_grads);\n  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n  at::Tensor input_lin_q_grads         = torch::empty_like(inputs_q);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n  \n  auto q_lin_grads_ptr   = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  \n  // Dropout Add Backward  \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),\n\t\t\t\t\t\t\t total_tokens_q,\n                             (1.0 / (1.0 - dropout_prob)));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, 
\n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                  
           a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim_kv, \n                             batch_stride_kv,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<at::Half*>(matmul2_grads.data_ptr()),\n\t\t\t\t\t\t\t 
static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t\t\t\t\t dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             
static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Input Linear Q Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_q, \n                             output_lin_q_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_q_dim, \n                             static_cast<const void*>(&beta),\n                             //static_cast<void*>(input_q_grads.data_ptr()),\n                             static_cast<void*>(input_lin_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Q Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_q_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F,\n      
                       embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_q_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_kv, \n                             output_lin_kv_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_kv_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_kv_dim,\n                             batches_kv, \n                             static_cast<const void*>(&alpha),\n                          
   static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_kv_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half,float>(\n                             static_cast<const half*>(input_lin_q_grads.data_ptr()),\n                             static_cast<half const*>(output_grads.data_ptr()), \n                             static_cast<const float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<const float*>(lyr_nrm_invvar.data_ptr()),\n                             inputs_q,\n                             static_cast<int>(batches_q),  // n1\n                             static_cast<int>(embed_dim),  // n2\n                             static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),\n                             static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),\n                             1.0e-5,\n                             static_cast<half*>(input_q_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_beta_grads.data_ptr())\n                                  );\n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_q_grads, \n           input_kv_grads, \n           lyr_nrm_gamma_grads, \n           lyr_nrm_beta_grads, \n           input_weight_q_grads, \n           input_weight_kv_grads, 
\n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec_norm_add \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/layer_norm.h",
    "content": "#include \"ATen/ATen.h\"\n#include <THC/THCDeviceUtils.cuh>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\ntemplate<typename U> __device__\nvoid cuWelfordOnlineSum(\n  const U curr,\n  U& mu,\n  U& sigma2,\n  U& count)\n{\n  count = count + U(1);\n  U delta = curr - mu;\n  U lmean = mu + delta / count;\n  mu = lmean;\n  U delta2 = curr - lmean;\n  sigma2 = sigma2 + delta * delta2;\n}\n\ntemplate<typename U> __device__\nvoid cuChanOnlineSum(\n  const U muB,\n  const U sigma2B,\n  const U countB,\n  U& mu,\n  U& sigma2,\n  U& count)\n{\n  U delta = muB - mu;\n  U nA = count;\n  U nB = countB;\n  count = count + countB;\n  U nX = count;\n  if (nX > U(0)) {\n    nA = nA / nX;\n    nB = nB / nX;\n    mu = nA*mu + nB*muB;\n    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;\n  } else {\n    mu = U(0);\n    sigma2 = U(0);\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuWelfordMuSigma2(\n  const T* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const int i1,\n  U& mu,\n  U& sigma2,\n  U* buf) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  U count = U(0);\n  mu= U(0);\n  sigma2 = U(0);\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const T* lvals = vals + i1*n2;\n    int l = 4*thrx;\n    for (;  l+3 < n2;  l+=4*numx) {\n      for (int k = 0;  k < 4;  ++k) {\n        U curr = static_cast<U>(lvals[l+k]);\n        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);\n      }\n    }\n    for (;  l < n2;  ++l) {\n      U curr = static_cast<U>(lvals[l]);\n      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);\n    }\n    // intra-warp reductions\n    for (int l = 
0;  l <= 4;  ++l) {\n      int srcLaneB = (threadIdx.x+(1<<l))&31;\n      U muB = WARP_SHFL(mu, srcLaneB);\n      U countB = WARP_SHFL(count, srcLaneB);\n      U sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      U* ubuf = (U*)buf;\n      U* ibuf = (U*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2*wrt_y] = mu;\n          ubuf[2*wrt_y+1] = sigma2;\n          ibuf[wrt_y] = count;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          U muB = ubuf[2*threadIdx.y];\n          U sigma2B = ubuf[2*threadIdx.y+1];\n          U countB = ibuf[threadIdx.y];\n          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1]/U(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2/U(n2), 0);\n    }\n  }\n}\n\ntemplate<> __device__\nvoid cuWelfordMuSigma2(\n  const at::Half* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const int i1,\n  float& mu,\n  float& sigma2,\n  float* buf) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean 
over n2\n  float count = 0.0f;\n  mu= float(0);\n  sigma2 = float(0);\n\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const at::Half* lvals = vals + i1*n2;\n    int l = 8*thrx;\n    if ((((size_t)lvals)&3) != 0) {\n      // 16 bit alignment\n      // first thread consumes first point\n      if (thrx == 0) {\n        float curr = static_cast<float>(lvals[0]);\n        cuWelfordOnlineSum(curr,mu,sigma2,count);\n      }\n      ++l;\n    }\n    // at this point, lvals[l] are 32 bit aligned for all threads.\n    for (;  l+7 < n2;  l+=8*numx) {\n      for (int k = 0;  k < 8;  k+=2) {\n        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));\n        cuWelfordOnlineSum(curr.x,mu,sigma2,count);\n\tcuWelfordOnlineSum(curr.y,mu,sigma2,count);\n      }\n    }\n    for (;  l < n2;  ++l) {\n      float curr = static_cast<float>(lvals[l]);\n      cuWelfordOnlineSum(curr,mu,sigma2,count);\n    }\n    // intra-warp reductions\n    for (int l = 0;  l <= 4;  ++l) {\n      int srcLaneB = (threadIdx.x+(1<<l))&31;\n      float muB = WARP_SHFL(mu, srcLaneB);\n      float countB = WARP_SHFL(count, srcLaneB);\n      float sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      float* ubuf = (float*)buf;\n      float* ibuf = (float*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2*wrt_y] = mu;\n          ubuf[2*wrt_y+1] = sigma2;\n          ibuf[wrt_y] = count;\n     
   }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          float muB = ubuf[2*threadIdx.y];\n          float sigma2B = ubuf[2*threadIdx.y+1];\n          float countB = ibuf[threadIdx.y];\n          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1]/float(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2/float(n2), 0);\n    }\n  }\n}\n\ntemplate<typename U> U rsqrt(U v) {\n  return U(1) / sqrt(v);\n}\ntemplate<> float rsqrt(float v) {\n  return rsqrtf(v);\n}\ntemplate<> double rsqrt(double v) {\n  return rsqrt(v);\n}\n\nnamespace {\n// This is the un-specialized struct.  
Note that we prevent instantiation of this\n// struct by putting an undefined symbol in the function body so it won't compile.\n//  template <typename T>\n//  struct SharedMemory\n//  {\n//      // Ensure that we won't compile any un-specialized types\n//      __device__ T *getPointer()\n//      {\n//          extern __device__ void error(void);\n//          error();\n//          return NULL;\n//      }\n//  };\n// https://github.com/NVIDIA/apex/issues/246\ntemplate <typename T>\nstruct SharedMemory;\n\ntemplate <>\nstruct SharedMemory <float>\n{\n    __device__ float *getPointer()\n    {\n        extern __shared__ float s_float[];\n        return s_float;\n    }\n};\n\ntemplate <>\nstruct SharedMemory <double>\n{\n    __device__ double *getPointer()\n    {\n        extern __shared__ double s_double[];\n        return s_double;\n    }\n};\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuApplyLayerNorm(\n  T* __restrict__ output_vals,\n  U* __restrict__ mean,\n  U* __restrict__ invvar,\n  const T* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const U epsilon,\n  const T* __restrict__ gamma,\n  const T* __restrict__ beta\n  ) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensors are contiguous\n  //\n  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer();\n    U mu,sigma2;\n    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);\n    const T* lvals = vals + i1*n2;\n    T* ovals = output_vals + i1*n2;\n    U c_invvar = rsqrt(sigma2 + epsilon);\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL && beta != NULL) {\n      for (int i = thrx;  i < n2;  i+=numx) {\n        U curr = static_cast<U>(lvals[i]);\n        ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];\n      }\n    } else {\n      for (int i = thrx;  i < n2;  i+=numx) {\n        U curr = static_cast<U>(lvals[i]);\n 
       ovals[i] = static_cast<T>(c_invvar * (curr - mu));\n      }\n    }\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      mean[i1] = mu;\n      invvar[i1] = c_invvar;\n    }\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuLoadWriteStridedInputs(\n    const int i1_block,\n    const int thr_load_row_off,\n    const int thr_load_col_off,\n    const int i2_off,\n    const int row_stride,\n    U* warp_buf1,\n    U* warp_buf2,\n    const T* input,\n    const T* dout,\n    const int i1_end,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar\n    )\n{\n  int i1 = i1_block+thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = invvar[i1];\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1*n2+i2;\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      if (i2<n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n\tU curr_dout = static_cast<U>(dout[load_idx]);\n\twarp_buf1[write_idx] = curr_dout;\n\twarp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;\n      } else {\n        warp_buf1[write_idx] = U(0);\n        warp_buf2[write_idx] = U(0);\n      }\n    }\n  } else {\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      warp_buf1[write_idx] = U(0);\n      warp_buf2[write_idx] = U(0);\n    }\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuLoadAddStridedInputs(\n    const int i1_block,\n    const int thr_load_row_off,\n    const int thr_load_col_off,\n    const int i2_off,\n    const int row_stride,\n    U* warp_buf1,\n    U* warp_buf2,\n    const T* input,\n    const T* dout,\n    const int i1_end,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar\n    )\n{\n  int i1 = i1_block+thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = 
invvar[i1];\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1*n2+i2;\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      if (i2<n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n\tU curr_dout = static_cast<U>(dout[load_idx]);\n\twarp_buf1[write_idx] += curr_dout;\n\twarp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputePartGradGammaBeta(\n    const T* __restrict__ dout,\n    const T* __restrict__ input,\n    const int n1,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar,\n    U epsilon,\n    U* part_grad_gamma,\n    U* part_grad_beta)\n{\n    const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);\n    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;\n    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;\n    const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;\n    const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1;\n    const int row_stride = blockDim.x+1;\n    const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);\n    const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;\n    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements\n    U* warp_buf1 = (U*)buf;\n    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;\n    // compute partial sums from strided inputs\n    // do this to increase number of loads in flight\n    cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);\n    for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {\n      cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);\n    }\n    __syncthreads();\n    // inter-warp reductions\n    // sum within each warp\n    U acc1 = U(0);\n    U acc2 = U(0);\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int row1 = threadIdx.y + k*blockDim.y;\n      int idx1 = row1*row_stride + threadIdx.x;\n      acc1 += warp_buf1[idx1];\n      acc2 += warp_buf2[idx1];\n    }\n    warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;\n    warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;\n    __syncthreads();\n    // sum all warps\n    for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {\n      if (threadIdx.y < offset) {\n        int row1 = threadIdx.y;\n\tint row2 = threadIdx.y + offset;\n\tint idx1 = row1*row_stride + threadIdx.x;\n\tint idx2 = row2*row_stride + threadIdx.x;\n\twarp_buf1[idx1] += warp_buf1[idx2];\n\twarp_buf2[idx1] += warp_buf2[idx2];\n      }\n      __syncthreads();\n    }\n    int i2 = blockIdx.x * blockDim.x + 
threadIdx.x;\n    if (threadIdx.y == 0 && i2 < n2) {\n      int row1 = threadIdx.y;\n      int row2 = threadIdx.y + 1;\n      int idx1 = row1*row_stride + threadIdx.x;\n      int idx2 = row2*row_stride + threadIdx.x;\n      part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];\n      part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];\n    }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputeGradGammaBeta(\n    const U* part_grad_gamma,\n    const U* part_grad_beta,\n    const int part_size,\n    const int n1,\n    const int n2,\n    T* grad_gamma,\n    T* grad_beta)\n{\n    // sum partial gradients for gamma and beta\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer(); \n    int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i2 < n2) {\n      // each warp does sequential reductions until reduced part_size is num_warps\n      int num_warp_reductions = part_size / blockDim.y;\n      U sum_gamma = U(0);\n      U sum_beta = U(0);\n      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;\n      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;\n      for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {\n        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];\n        sum_beta += part_grad_beta_ptr[warp_offset*n2];\n      }\n      // inter-warp reductions\n      const int nbsize3 = blockDim.x * blockDim.y / 2;\n      for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {\n        // top half write to shared memory\n        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          buf[write_idx] = sum_gamma;\n          buf[write_idx+nbsize3] = sum_beta;\n        }\n        __syncthreads();\n        // bottom half sums\n        if (threadIdx.y < offset) {\n          const int read_idx = threadIdx.y 
* blockDim.x + threadIdx.x;\n          sum_gamma += buf[read_idx];\n          sum_beta += buf[read_idx+nbsize3];\n        }\n        __syncthreads();\n      }\n      // write out fully summed gradients\n      if (threadIdx.y == 0) {\n        grad_gamma[i2] = sum_gamma;\n        grad_beta[i2] = sum_beta;\n      }\n    }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputeGradInput(\n    const T* __restrict__ dout,\n    const T* __restrict__ dout_resid,\n    const T* __restrict__ input,\n    const int n1,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar,\n    U epsilon,\n    const T* gamma,\n    T* grad_input)\n{\n  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    U sum_loss1 = U(0);\n    U sum_loss2 = U(0);\n    const U c_mean = mean[i1];\n    const U c_invvar = invvar[i1];\n    const T* k_input = input + i1*n2;\n    const T* k_dout = dout + i1*n2;\n    const T* k_dout_resid = dout_resid + i1*n2;\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL) {\n      int l = 4*thrx;\n      for (;  l+3 < n2;  l+=4*numx) {\n        for (int k = 0;  k < 4;  ++k) {\n          const U c_h = static_cast<U>(k_input[l+k]);\n          const U c_loss = static_cast<U>(k_dout[l+k]);\n          sum_loss1 += c_loss * static_cast<U>(gamma[l+k]);\n          sum_loss2 += c_loss * static_cast<U>(gamma[l+k]) * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (;  l < n2;  ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss * static_cast<U>(gamma[l]);\n        sum_loss2 += c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;\n      }\n    } else {\n      int l = 4*thrx;\n      for (;  l+3 < n2;  l+=4*numx) {\n        for (int k = 0;  k < 4;  ++k) {\n          const U c_h = static_cast<U>(k_input[l+k]);\n          const U c_loss = 
static_cast<U>(k_dout[l+k]);\n          sum_loss1 += c_loss;\n          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (;  l < n2;  ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss;\n        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n      }\n    }\n    // intra-warp reductions\n    for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {\n      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);\n      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);\n    }\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      SharedMemory<U> shared;\n      U* buf = shared.getPointer(); \n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          buf[2*wrt_i] = sum_loss1;\n          buf[2*wrt_i+1] = sum_loss2;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.y < offset) {\n          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;\n          sum_loss1 += buf[2*read_i];\n          sum_loss2 += buf[2*read_i+1];\n        }\n        __syncthreads();\n      }\n      if (threadIdx.y == 0) {\n        buf[2*threadIdx.x] = sum_loss1;\n        buf[2*threadIdx.x+1] = sum_loss2;\n      }\n      __syncthreads();\n      if (threadIdx.y !=0) {\n        sum_loss1 = buf[2*threadIdx.x];\n        sum_loss2 = buf[2*threadIdx.x+1];\n      } \n    }\n    // all threads now have the two sums over l\n    U fH = (U)n2;\n    U term1 = (U(1) / fH) * c_invvar;\n    T* k_grad_input = grad_input + i1*n2;\n    if (gamma != NULL) {\n      for (int l = thrx;  l < n2;  l+=numx) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid= 
static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;\n      }\n    } else {\n      for (int l = thrx;  l < n2;  l+=numx) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid= static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss;\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename U> \nvoid HostApplyLayerNorm(\n    T* output,\n    U* mean,\n    U* invvar,\n    const T* input,\n    int n1,\n    int n2,\n    double epsilon,\n    const T* gamma,\n    const T* beta\n    )\n{\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    const dim3 threads(32,4,1);\n    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);\n    int nshared = \n        threads.y > 1 ? 
\n\t    threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : \n\t    0;\n    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(\n\t\t    output,\n\t\t    mean,\n\t\t    invvar,\n\t\t    input,\n\t\t    n1,n2,\n\t\t    U(epsilon),\n            gamma,beta);\n}\n\ntemplate<typename T, typename U> \nvoid HostLayerNormGradient(\n    const T* dout,\n    const T* dout_resid,\n    const U* mean,\n    const U* invvar,\n    const at::Tensor& input,\n    int n1,\n    int n2,\n    const T* gamma,\n    const T* beta,\n    double epsilon,\n    T* grad_input,\n    T* grad_gamma,\n    T* grad_beta\n    )\n{\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    if (gamma != NULL && beta != NULL) {\n      // compute grad_gamma(j) and grad_beta(j)\n      const int part_size = 16;\n      const dim3 threads2(32,4,1);\n      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);\n      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);\n      const int nshared2_b = threads2.x * threads2.y * sizeof(U);\n      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;\n      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? 
at::ScalarType::Float : input.scalar_type()));\n      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);\n      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(\n\t\t      dout,\n\t\t      static_cast<T*>(input.data_ptr()),\n\t\t      n1,n2,\n\t\t      mean,\n\t\t      invvar,\n\t\t      U(epsilon),\n\t\t      static_cast<U*>(part_grad_gamma.data_ptr()),\n\t\t      static_cast<U*>(part_grad_beta.data_ptr()));\n\n      const dim3 threads3(32,8,1);\n      const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);\n      const int nshared3 = threads3.x * threads3.y * sizeof(U);\n      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(\n\t\t      static_cast<U*>(part_grad_gamma.data_ptr()),\n\t\t      static_cast<U*>(part_grad_beta.data_ptr()),\n\t\t      part_size,\n\t\t      n1,n2,\n\t\t      grad_gamma,\n\t\t      grad_beta);\n    }\n\n    // compute grad_input\n    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);\n    const dim3 threads1(32,4,1);\n    int nshared =\n\t    threads1.y > 1 ?\n\t    threads1.y*threads1.x*sizeof(U) :\n\t    0;\n    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(\n            dout,\n\t    dout_resid,\n            static_cast<T*>(input.data_ptr()),\n            n1,n2,\n            mean,\n            invvar,\n            U(epsilon),\n            gamma,\n            grad_input);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               const uint8_t *padding_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t       bool \t\t\t\tuse_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(input.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  
\tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(\n                                 is_training,\n                                 heads, \n                                 input, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\ntorch::Tensor bwd(\n\t\t               bool use_mask,\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& padding_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n//  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n\t\t                 heads,\n                                 output_grads,\n                                 softmax_results, \n                                 dropout_mask, \n                                 use_mask ? 
static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace mask_softmax_dropout\n} // end namespace fused_softmax\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, \"Self Multihead Attention masked softmax dropout -- Forward.\");\n  m.def(\"backward\", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, \"Self Multihead Attention masked softmax dropout -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"softmax.h\"\n#include \"dropout.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n\t\t\t       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = input.size(0);\n  const int   sequences      = attn_batches / heads;\n  const int   q_seq_len      = input.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n\n  // Softmax Intermediate 
Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {\n           dropout_results,  \n           dropout_mask, \n           softmax_results\n         };\n}\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results, \n                               torch::Tensor const& dropout_mask,\n                               const uint8_t  *padding_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = 
output_grads.size(0);\n  const int   q_seq_len      = output_grads.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n//  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  if (padding_mask == nullptr) {\n      dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n  } else{\n      dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(padding_mask),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n\t\t\t     heads, stream); \n  \n  }\n//backward pass is completely in-place\n  return 
output_grads;\n}\n}\n}\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/philox.h",
    "content": "#pragma once\n//Philox CUDA. \n\nclass Philox {\npublic:\n  __device__ inline Philox(unsigned long long seed,\n                           unsigned long long subsequence,\n                           unsigned long long offset) {\n    key.x = (unsigned int)seed;\n    key.y = (unsigned int)(seed >> 32);\n    counter = make_uint4(0, 0, 0, 0);\n    counter.z = (unsigned int)(subsequence);\n    counter.w = (unsigned int)(subsequence >> 32);\n    STATE = 0;\n    incr_n(offset / 4);\n  }\n  __device__ inline uint4 operator()() {\n    if(STATE == 0) {\n      uint4 counter_ = counter;\n      uint2 key_ = key;\n      //7-round philox\n      for(int i = 0; i < 6; i++) {\n        counter_ = single_round(counter_, key_);\n        key_.x += (kPhilox10A); key_.y += (kPhilox10B);\n      }\n      output = single_round(counter_, key_);\n      incr();\n    }\n    //return a float4 directly\n    //unsigned long ret;\n    //switch(STATE) {\n    //  case 0: ret = output.x; break;\n    //  case 1: ret = output.y; break;\n    //  case 2: ret = output.z; break;\n    //  case 3: ret = output.w; break;\n    //}\n    //STATE = (STATE + 1) % 4;\n    return output;\n  }\nprivate:\n  uint4 counter;\n  uint4 output;\n  uint2 key;\n  unsigned int STATE;\n  __device__ inline void incr_n(unsigned long long n) {\n    unsigned int nlo = (unsigned int)(n);\n    unsigned int nhi = (unsigned int)(n >> 32);\n    counter.x += nlo;\n    if (counter.x < nlo)\n      nhi++;\n    counter.y += nhi;\n    if (nhi <= counter.y)\n      return;\n    if (++counter.z)\n      return;\n    ++counter.w;\n  }\n  __device__ inline void incr() {\n    if (++counter.x)\n      return;\n    if (++counter.y)\n      return;\n    if (++counter.z)\n      return;\n    ++counter.w;\n  }\n  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,\n                                    unsigned int *result_high) {\n    *result_high = __umulhi(a, b);\n    return a*b;\n  }\n  __device__ inline uint4 single_round(uint4 
ctr, uint2 key) {\n    unsigned int hi0;\n    unsigned int hi1;\n    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);\n    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);\n    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};\n    return ret;\n  }\n  static const unsigned long kPhilox10A = 0x9E3779B9;\n  static const unsigned long kPhilox10B = 0xBB67AE85;\n  static const unsigned long kPhiloxSA = 0xD2511F53;\n  static const unsigned long kPhiloxSB = 0xCD9E8D57;\n};\n// Inverse of 2^32.\n#define M_RAN_INVM32 2.3283064e-10f\n__device__  __inline__ float4 uniform4(uint4 x) {\n    return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);\n\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t\t\t\t\t   bool \t\t\t\tuse_mask,\n                               bool                 use_time_mask,\n        
                       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, 
\"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self::cublas_gemmex::fwd, \"Self Multihead Attention Forward.\");\n  m.def(\"backward\", &multihead_attn::self::cublas_gemmex::bwd, \"Self Multihead Attention Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               //torch::Tensor const& input_biases,\n                               //torch::Tensor const& output_biases,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define 
CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t\t\t\t\t   bool \t\t\t\tuse_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 input_biases, \n                                 
output_biases, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  
AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_bias\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self_bias::cublas_gemmex::fwd, \"Self Multihead Attention with Bias -- Forward.\");\n  m.def(\"backward\", &multihead_attn::self_bias::cublas_gemmex::bwd, \"Self Multihead Attention with Bias -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n#include <cuda_fp16.h>\n\nnamespace multihead_attn {\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const half*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                              // torch::Tensor const& softmax_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               //torch::Tensor const& input_biases,\n                               //torch::Tensor const& output_biases,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n              
                                    );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(use_mask                                                  , \"a pad mask is required: the no-mask path is not supported\");\n\n  if (use_mask) {\n    AT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n    AT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  }\n\n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                   
              heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 input_biases, \n                                 output_biases, \n                                 use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is 
supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n\t\t\t\t bmm1_results,\n\t\t\t\t pad_mask, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, \"Self Multihead Attention with Bias -- Forward.\");\n  m.def(\"backward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, \"Self Multihead Attention with Bias -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const half*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta_zero       = 0.0;\n  const float beta_one           = 1.0;\n  const 
float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor bmm1_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());\n  void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  
THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta_zero, \n                             static_cast<half*>(bmm1_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n  // Padded Softmax\n  bool softmax_success = false;\n  if (is_training) {\n      softmax_success = 
dispatch_additive_masked_softmax_dropout<half, half, float>(\n                           reinterpret_cast<half*>(dropout_results_ptr),\n                           (is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,\n                           reinterpret_cast<const half*>(bmm1_results_ptr),\n                           pad_mask,\n      \t\t           attn_batches*q_seq_len*q_seq_len,\n                           k_seq_len,\n                           k_seq_len,\n                           attn_batches*q_seq_len,\n                           attn_batches*q_seq_len/sequences, \n      \t\t           1.0f-dropout_prob,\n\t\t           stream);\n  } else {\n      softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function\n                             reinterpret_cast<const half*>(bmm1_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta_zero, \n                             
static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           input_lin_results,  \n           bmm1_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n              
                 torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = 
static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const 
void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                           
  static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half* const>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(bmm1_results.data_ptr()),\n                             reinterpret_cast<half const*>(pad_mask.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len/sequences,\n                             attn_batches*q_seq_len,\n                             stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n        
                      q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                             //static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             
CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads,\n           input_bias_grads, \n           output_bias_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta_zero      = 0.0;\n  const float beta_one       = 1.0;\n  const float scale  
        = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                   
          CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta_zero, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             
reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                          
   batch_stride, \n                             (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta_zero, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           input_lin_results,  \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n        
                       torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor 
Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  
THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n\n  // Matmul2 Dgrad2\n  
 gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             
static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                             //static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n 
                            static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads,\n           input_bias_grads, \n           output_bias_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle 
= at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n     
                        static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  
} else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(softmax_results.data_ptr()),\n                               static_cast<at::Half*>(dropout_results.data_ptr()),\n                               static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                               dropout_elems,\n                               (1.0f - dropout_prob));\n  }\n \n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_lin_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               
torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = 
torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n  \n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                    
         embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             
q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                 
            lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             
static_cast<const void*>(&beta),\n                             static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                         
      torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()                 == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()   == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights.dim()          == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()         == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half, 
\"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs,\n                                 lyr_nrm_gamma_weights,\n                                 lyr_nrm_beta_weights,\n                                 input_weights, \n                                 output_weights, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               
torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim()     == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, \"expected 1D tensor\");\n  AT_ASSERTM(inputs.dim()                == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights.dim()         == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_add_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType()     == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  
AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  \n  return bwd_cuda(heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 lyr_nrm_results,\n                                 lyr_nrm_mean,\n                                 lyr_nrm_invvar,\n                                 inputs, \n\t\t\t\t\t\t\t     lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t\t lyr_nrm_beta_weights,\n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_add_mask,\n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_norm_add \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n 
 m.def(\"forward\", &multihead_attn::self_norm_add::cublas_gemmex::fwd, \"Self Multihead Attention Plus Layer Norm and Residual Add Forward.\");\n  m.def(\"backward\", &multihead_attn::self_norm_add::cublas_gemmex::bwd, \"Self Multihead Attention Plus Layer Norm and Residual Add Backward.\");\n}\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   total_tokens   = batches * embed_dim;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float 
scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options                = inputs.options().requires_grad(false);\n  auto lyr_nrm_options            = act_options.dtype(torch::kFloat32);\n  auto mask_options               = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor lyr_nrm_mean      = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar    = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results   = torch::empty_like(inputs, act_options);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor output_lin_results= torch::empty_like(inputs, act_options);\n  torch::Tensor dropout_add_mask  = torch::empty_like(inputs, mask_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = 
static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half,float>(\n                             static_cast<at::Half*>(lyr_nrm_results.data_ptr()),\n                             static_cast<float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<float*>(lyr_nrm_invvar.data_ptr()),\n                             static_cast<const at::Half*>(inputs.data_ptr()),\n                             static_cast<int>(batches),   // n1\n                             static_cast<int>(embed_dim), // n2\n                             1.0e-5,\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             //static_cast<const void*>(inputs.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n           
                  output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n 
   } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(softmax_results.data_ptr()), \n                             static_cast<at::Half*>(dropout_results.data_ptr()), \n                             static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             //static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len,  \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()),  \n                             head_dim*attn_batches,  \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // End-of-block Dropout-Add \n  if (is_training) {\n    apex_dropout_add_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                               static_cast<at::Half const*>(inputs.data_ptr()), \n                               static_cast<at::Half*>(outputs.data_ptr()), \n                               
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),\n                               total_tokens,\n                               (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                               static_cast<at::Half const*>(inputs.data_ptr()), \n                               static_cast<at::Half*>(outputs.data_ptr()), \n                               total_tokens);\n  }\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           lyr_nrm_results,\n\t\t   lyr_nrm_mean,\n           lyr_nrm_invvar, \n           input_lin_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results,\n           dropout_add_mask, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& 
dropout_add_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   total_tokens   = batches * embed_dim;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Multiple streams could be used in backprop; this initial\n  // implementation uses only one.\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_grads            = torch::empty_like(inputs);\n  torch::Tensor lyr_nrm_gamma_grads    = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads     = torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_grads     = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  torch::Tensor dropout_add_grads      = torch::empty_like(output_grads);\n  torch::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  torch::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n  torch::Tensor input_lin_grads    
    = torch::empty_like(inputs);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n  \n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Dropout Add Backward  \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),\n       \t\t\t\t\t\t total_tokens,\n                             (1.0 / (1.0 - dropout_prob)));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             
CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                      
       k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<at::Half*>(matmul2_grads.data_ptr()),\n\t\t\t\t\t\t\t static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t\t\t\t\t dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                       
      a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const 
void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n                             //static_cast<void*>(input_grads.data_ptr()),\n                             static_cast<void*>(input_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             //static_cast<const void*>(inputs.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half,float>(\n                             static_cast<const 
half*>(input_lin_grads.data_ptr()),\n                             static_cast<half const*>(output_grads.data_ptr()), \n                             static_cast<const float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<const float*>(lyr_nrm_invvar.data_ptr()),\n                             inputs,\n                             static_cast<int>(batches),   // n1\n                             static_cast<int>(embed_dim), // n2\n                             static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),\n                             static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),\n                             1.0e-5,\n                             static_cast<half*>(input_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_beta_grads.data_ptr())\n                                   );\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n\t\t   input_grads, \n           lyr_nrm_gamma_grads, \n           lyr_nrm_beta_grads, \n           input_weight_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_norm_add \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/softmax.h",
    "content": "#pragma once\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n#include <curand_kernel.h>\n#include \"philox.h\"\n \n#include <assert.h>\n#include <cfloat>\n#include <limits>\n#include <stdint.h>\n#include <cuda_fp16.h>\n#include <cmath>\n \nnamespace {\n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);\n \n    template <>\n    __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }\n \n    template <>\n    __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }\n \n    template <>\n    __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } \n    template <>\n    __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }\n    \n    template <>\n    __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }\n   \n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);\n    \n    template <>\n    __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {\n      if (*src == 1) { *dst = value; }\n    }\n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask);\n    template <>\n    __device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {\n      *dst += *additive_mask; \n    }\n    template <>\n    __device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {\n      *dst += *additive_mask;\n      *(dst+1) += *(additive_mask+1);\n      *(dst+2) += *(additive_mask+2);\n      
*(dst+3) += *(additive_mask+3);}    \n} // namespace anonymous\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp Softmax forward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batch_size, int stride, int element_count)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_forward_func = void(*)(input_t *dst, const output_t *src, int batch_size, int stride, 
int element_count);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, int 
softmax_elements_stride, int batch_count)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)\n{\n \n    assert(ELEMENTS_PER_LDG_STG==4);\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n    int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n    acc_t pinv = acc_t(1)/p;\n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n     \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n    //vectorize if element_count is multiple of 4, else don't vectorize\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n    dropout_mask += thread_offset;\n    \n    // load data from global memory\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                //masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n    \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()\n            } \n    \n        }\n    }\n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) 
{\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n    auto seeds = at::cuda::philox::unpack(philox_args);\n    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));     \n    uint8_t rands[WARP_BATCH][WARP_ITERATIONS];\n    float4 rand_num;\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n\t#pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it+=ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n\t\trand_num = uniform4(ph());\n                rands[i][it] = (rand_num.x <= p) > 0.5;  \n                rands[i][it+1] = (rand_num.y <= p) > 0.5;\n                rands[i][it+2] = (rand_num.z <= p) > 0.5;\n                rands[i][it+3] = (rand_num.w <= p) > 0.5;\n                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);\n\t    }\n        }\n    }\n\n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                output_t 
out[ELEMENTS_PER_LDG_STG];\n                #pragma unroll\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = rands[i][it+element] * (pinv * (elements[i][it + element] / sum[i]));\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n    \n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n    int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n    acc_t pinv = acc_t(1)/p;\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n     \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n    //vectorize if element_count is multiple of 4, else don't vectorize\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    int thread_offset =  first_batch * stride + local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n    dropout_mask += thread_offset;\n    \n    // load data from global memory\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += 1) {\n            int element_index = local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < 1;++element) {\n    \t//masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n    \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);\n                apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp); \n            } \n    \n        }\n    }\n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n    curandStatePhilox4_32_10_t state;\n    auto seeds = at::cuda::philox::unpack(philox_args);\n    curand_init(\n      std::get<0>(seeds),\n      tid,\n      std::get<1>(seeds),\n      &state);\n     \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += 1) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                output_t out[1];\n                acc_t softmax_out[1];\n                uint8_t dropout_mask_temp[1];\n                //generate a vector of random numbers here \n                float rand = curand_uniform(&state);\n                float *rand_ptr = (float*)(&rand);    \n                #pragma unroll\n                for (int 
element = 0;element < 1;++element) {\n                    softmax_out[element] = (elements[i][it + element] / sum[i]);\n                    rand_ptr[element] = rand_ptr[element] <= p;\n                    out[element] = rand_ptr[element] * pinv * softmax_out[element];\n                    dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f\n                }\n                copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);\n                copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);\n    \n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG is 1 for the scalar kernel and 4 for the vec4 kernel.\ntemplate <typename input_t, typename output_t, typename acc_t>\nusing additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride,  at::PhiloxCudaState philox_args, float p);\n\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n    bool flag_vec4 = (element_count % 4 == 0); \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n\tif (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n\tif (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,16,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,32,32,4>;\n\telse kernel = 
&additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    case 11: // 2048\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,64,32,4>;\n        else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,64,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\n\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int totalElements, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, cudaStream_t streamid)// p is the probability to keep, not drop\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 2048) {\n        // compute function index. there's a function for each power of two size up to 2048.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        c10::optional<at::Generator> gen_;\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        int64_t 
counter_offset = (totalElements/(blocks*threads_per_block)+1);\n        at::PhiloxCudaState rng_engine_inputs;\n        {\n          std::lock_guard<std::mutex> lock(gen->mutex_);\n          rng_engine_inputs = gen->philox_cuda_state(counter_offset);\n        }\n \n        // compute launch size\n        dim3 threads(warp_size, warps_per_block, 1);\n         \n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride, rng_engine_inputs, p);\n        return true;\n    }\n    return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n\t\t//masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                //apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                //                                          (__half)-std::numeric_limits<float>::infinity(), \n                //                                          curr_mask + itr_jmp);\n                elements_input[i][it] += *(curr_mask + itr_jmp);\n\t    } \n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the 
max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = 
elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing additive_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const half *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        additive_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n        additive_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\n\n\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                                                          (__half)-std::numeric_limits<float>::infinity(), \n                                                          curr_mask + itr_jmp);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            
elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, 
int batch_size, int stride, int element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename 
input_t, typename output_t, typename acc_t>\nbool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *src, const 
uint8_t *pad_mask, int batch_size, int stride, int element_count, int mod_seq_len)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                                                          (__half)-std::numeric_limits<float>::infinity(), \n                                                   
       curr_mask + itr_jmp);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing time_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t 
*pad_mask, int batch_size, int stride, int element_count, int mod_seq_len);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, time_masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        
return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int mod_seq_len)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        time_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);\n        return true;\n    }\n    return false;\n}\n\nint log2_ceil_native(int value) {\n    int log2_value = 0;\n    while ((1 << log2_value) < value) ++log2_value;\n    return log2_value;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)\n{\n#if CUDA_VERSION >= 9000\n    return __shfl_xor_sync(mask, value, laneMask, width);\n#else\n    return __shfl_xor(value, laneMask, 
width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_BATCH, int WARP_SIZE>\n__device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;  i < WARP_BATCH;  ++i) {\n            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);\n            sum[i] = sum[i] + b;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward functions as fused variants of at::softmax_backward_data function\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// softmax backward data function is taken from native pytorch, elementwise mul is fused in the epilogue, as well as masking and scaling for fusing dropout\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, int stride, int element_count, int heads)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; \n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n\t\tint total_ind = thread_offset + i*element_count + it*WARP_SIZE;\n\t\tint pad_mask_ind =  element_count*(total_ind/(heads * element_count * element_count)) + total_ind%element_count;\n\t\tuint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];\n\t\tif (pad_mask_element == 0) gradInput[i*element_count+it*WARP_SIZE] = 0;\n\t\telse {\n                  if (is_log_softmax) {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                  } else {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                  }\n\t\t}\n            }\n        }\n    
}\n}\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n             
       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 3: // 8\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 4: // 16\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 7: // 
128\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 9: // 512\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads, cudaStream_t streamid)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        
int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 3: // 8\n   
             masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 4: // 16\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 7: // 128\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 9: // 512\n                
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]  ;\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; \n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  
it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                if (is_log_softmax) {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                } else {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                }\n            }\n        }\n    }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n    // vectorize if the row length is a multiple of 4\n    // note: == binds tighter than &, so the bitmask test must be parenthesized\n    int flag_vec4 = ((element_count & 3) == 0);\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n\n    grad += thread_offset;\n    softmax_input += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const input_t* curr_mask = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                // masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n                grad_reg[i][it+element] = acc_t(0);\n            }\n\n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);\n                apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()\n                uint8_t mask_temp[ELEMENTS_PER_LDG_STG];\n                input_t grad_temp[ELEMENTS_PER_LDG_STG];\n                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);\n                #pragma unroll\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    grad_reg[i][it+element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale);\n                }\n            }\n\n        }\n    }\n\n    // convert input_t to acc_t\n    // TODO : remove this, input is already acc_t type in register\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        #pragma unroll\n     
   for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n\n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n\n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n\n    // normalize to softmax probabilities and apply the incoming gradient\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements[i][it] / sum[i];\n            grad_reg[i][it] = grad_reg[i][it] * elements[i][it];\n        }\n    }\n\n    acc_t grad_sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        grad_sum[i] = grad_reg[i][0];\n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            grad_sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t grad_input_reg[ELEMENTS_PER_LDG_STG];\n                #pragma unroll\n                for (int element=0; element<ELEMENTS_PER_LDG_STG; ++element) {\n                    if (is_log_softmax) {\n                        grad_input_reg[element] = (grad_reg[i][it+element] - std::exp(elements[i][it+element]) * grad_sum[i]);\n                    } else {\n                        grad_input_reg[element] = (grad_reg[i][it+element] - elements[i][it+element] * grad_sum[i]);\n                    }\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);\n            }\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nusing masked_scale_softmax_warp_backward_recompute_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n    bool flag_vec4 = (element_count % 4 == 0);\n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,1,1, is_log_softmax>;\n        break;\n    case 1: // 2\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,2,1, is_log_softmax>;\n        break;\n    case 2: // 4\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,4,1, is_log_softmax>;\n        break;\n    case 3: // 8\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,8,1, is_log_softmax>;\n        break;\n    case 4: // 16\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,16,1, is_log_softmax>;\n        break;\n    case 5: // 32\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,32,1, is_log_softmax>;\n        break;\n    case 6: // 64\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,2,32,1, is_log_softmax>;\n        break;\n    case 7: // 128\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,4,32,1, is_log_softmax>;\n        break;\n    case 8: // 256\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,1, is_log_softmax>;\n        break;\n    case 9: // 512\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,1, is_log_softmax>;\n        break;\n    case 10: // 1024\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,1, is_log_softmax>;\n        break;\n    case 11: // 2048\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,1, is_log_softmax>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 2048) {\n        // compute function index. 
there's a function for each power of two size up to 2048.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n        masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;\n        int warp_size, batches_per_warp;\n        if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n\n        // compute launch size\n        dim3 threads(warp_size, warps_per_block, 1);\n\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride, pad_batch_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n        return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = 
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 3: // 8\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 4: // 16\n                
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 7: // 128\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 9: // 512\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, 
softmax_elements_stride, softmax_elements);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\n// The elementwise multiplication performed in at::softmax_backward_data is fused into the softmax dgrad kernel.\n// As a result of the fusion, the intermediate product is kept in fp32 registers instead of being stored in fp16.\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]  ;\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; //* output_reg[i][0];\n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];// * output_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index 
= local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                if (is_log_softmax) {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                } else {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                }\n            }\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 1: // 2\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 2: // 4\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 3: // 8\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 4: // 16\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, 
softmax_elements_stride, softmax_elements);\n                break;\n            case 5: // 32\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 6: // 64\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 7: // 128\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 8: // 256\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 9: // 512\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 10: // 1024\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            default:\n                break;\n        }\n 
   }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, int batch_size, int stride, int element_count)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n \n    // load data from global memory\n    input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n                copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert half to floating point\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            grad_reg[i][it] = grad_reg_input[i][it];\n            output_reg[i][it] = output_reg_input[i][it];\n        }\n    }\n \n \n    // compute thread local sum\n    acc_t sum[WARP_BATCH] = {0};\n    #pragma unroll\n    for (int it = 0;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += grad_reg[i][it] * output_reg[i][it];\n \n        }\n    }\n \n    // reduction sum\n    constexpr uint32_t FULL_MASK = 0xffffffff;\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t 
out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));\n                }\n                // store them in global memory\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);\n            }\n        }\n    }\n}\n \n \n \n// WARP_BATCH number of batches.\n// WARP_ITERATOINS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_backward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximimize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n        softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n        // use 128 threads per block to maximimize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n \n    // load data from global memory\n    input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n                copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert half to floating point\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            grad_reg[i][it] = grad_reg_input[i][it];\n            output_reg[i][it] = output_reg_input[i][it];\n        }\n    }\n \n \n    // compute thread local sum\n    acc_t sum[WARP_BATCH] = {0};\n    #pragma unroll\n    for (int it = 0;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += grad_reg[i][it] * output_reg[i][it];\n \n        }\n    }\n \n    // reduction sum\n    constexpr uint32_t FULL_MASK = 0xffffffff;\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n   
     for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));\n                }\n                // store them in global memory\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                // It is kind of unfortunate this has to be here to zero something out that is close to\n                // zero in the first place\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);\n            }\n        }\n    }\n}\n \n \n \n// WARP_BATCH number of batches.\n// WARP_ITERATOINS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int batch_size, int stride, int 
element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_backward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename 
output_t, typename acc_t>\nbool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        masked_softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximimize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/multihead_attn/strided_batched_gemm.h",
    "content": "#include <vector>\n#include <iostream>\n\n//#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/gemm/gemm.h\"\n#include \"cutlass/gemm/wmma_gemm_traits.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\ncublasOperation_t convertTransToCublasOperation(char trans) {\n  if (trans == 't') return CUBLAS_OP_T;\n  else if (trans == 'n') return CUBLAS_OP_N;\n  else if (trans == 'c') return CUBLAS_OP_C;\n  else {\n    THError(\"trans must be one of: t, n, c\");\n    return CUBLAS_OP_T;\n  }\n}\n\nvoid CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k,\n                    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                    float beta, half *c, long ldc, long strideC, long batchCount, cublasGemmAlgo_t algo=CUBLAS_GEMM_DEFAULT_TENSOR_OP) {\n    cublasOperation_t opa = convertTransToCublasOperation(transa);\n    cublasOperation_t opb = convertTransToCublasOperation(transb);\n\n    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n    cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n    cublasSetStream(handle, stream);\n    float fAlpha = alpha;\n    float fBeta = beta;\n    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n    THCublasCheck(cublasGemmStridedBatchedEx(handle,\n                                     opa, opb, (int)m, (int)n, (int)k,\n                                     (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,\n                                     b, CUDA_R_16F, (int)ldb, strideB,\n                                     (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,\n                                     (int)batchCount, CUDA_R_32F, 
algo));\n    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n}\n\ntemplate<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>\nvoid CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,\n                          float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                          float beta, half *c, long ldc, long strideC, long batchCount) {\n  //printf(\"CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\\n\", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);\n  typedef cutlass::gemm::WmmaGemmTraits<\n    A_LAYOUT,\n    B_LAYOUT,\n    cutlass::Shape<32, 16, 16>,\n    half,\n    half,\n    half,\n    cutlass::gemm::LinearScaling<float>,\n    float,\n    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,\n    typename cutlass::Shape<16, 16, 16>,\n    SRC_A,   //kScalarsPerLdgA_\n    SRC_B,   //kScalarsPerLdgB_\n    SRC_A,   //KScalarsPerLdsA_\n    SRC_B,   //KScalarsPerLdsB_\n    DST_C,   //kScalarsPerLdgCAndStgD_\n    DST_C/2, //kScalarsPerStsD_\n    DST_C/2  //kScalarsPerLdsD_\n  >\n    WmmaGemmTraits;\n\n  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;\n  typename Gemm::Params params;\n\n\n  int result = params.initialize(\n    m,                  // M dimension for each batch\n    n,                  // N dimension for each batch\n    k,                  // K dimension for each batch\n    alpha,              // scalar alpha\n    a,\n    lda,\n    strideA,     // distance in memory between the first element of neighboring batch\n    b,\n    ldb,\n    strideB,     // distance in memory between the first element of neighboring batch\n    beta,               // scalar beta\n    c,                  // 
source matrix C\n    ldc,\n    strideC,     // distance in memory between the first element of neighboring batch\n    c,                  // destination matrix C (may be different memory than source C matrix)\n    ldc,\n    strideC,    // distance in memory between the first element of neighboring batch\n    batchCount\n  );\n\n  AT_ASSERTM(result == 0, \"Failed to initialize CUTLASS Gemm::Params object.\");\n  \n  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits. \n  // To implement batched GEMM with larger batch size, we fragment it into\n  // smaller batched GEMMs of gridDim.z <= 64k\n  long batchesLeft    = batchCount;\n  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));\n  \n  do {\n  \t //printf(\"CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\\n\", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 
'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);\n    int result = params.initialize(\n      m,                  // M dimension for each batch\n      n,                  // N dimension for each batch\n      k,                  // K dimension for each batch\n      alpha,              // scalar alpha\n      a,\n      lda,\n      strideA,     // distance in memory between the first element of neighboring batch\n      b,\n      ldb,\n      strideB,     // distance in memory between the first element of neighboring batch\n      beta,               // scalar beta\n      c,                  // source matrix C\n      ldc,\n      strideC,     // distance in memory between the first element of neighboring batch\n      c,                  // destination matrix C (may be different memory than source C matrix)\n      ldc,\n      strideC,    // distance in memory between the first element of neighboring batch\n      iterBatchCount\n    );\n\n    AT_ASSERTM(result == 0, \"Failed to initialize CUTLASS Gemm::Params object.\");\n    // Launch the CUTLASS GEMM kernel.\n    THCudaCheck(Gemm::launch(params, stream));\n\n    // Update batched GEMM params based on completed work\n    batchesLeft = batchesLeft - iterBatchCount;\n    a += iterBatchCount * strideA;\n    b += iterBatchCount * strideB;\n    c += iterBatchCount * strideC;;\n\n    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));\n    \n  } while(batchesLeft > 0);\n}\n\nvoid gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,\n                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                           float beta, half *c, long ldc, long strideC, long batchCount) {\n  auto stream = c10::cuda::getCurrentCUDAStream();\n  //printf(\"GEMM   -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\\n\", (transa == 't' ? 
'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);\n  if        ( (transa == 't') && (transb == 'n') ) { \n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); \n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 
0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, 
batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, 
strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, 
b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n  } else if ( (transa == 'n') && (transb == 'n') ) {\n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 
0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, 
beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { 
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, 
batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n  } else if ( (transa == 'n') && (transb == 't') ) {\n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { \n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); \n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); \n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); \n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { 
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if 
(!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, 
beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, 
ldc, strideC, batchCount); }\n  } else {\n    AT_ASSERTM(false, \"TransA and TransB are invalid\");\n  }\n}\n\nvoid adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)\n{\n  int transa_ = ((transa == 't') || (transa == 'T'));\n  int transb_ = ((transb == 't') || (transb == 'T'));\n\n  // Note: leading dimensions generally are checked that they are > 0 and at least as big as the result\n  // requires (even if the value won't be used).\n  if(n <= 1)\n    *ldc = std::max<int64_t>(m, 1);\n\n  if(transa_)\n  {\n    if(m <= 1)\n      *lda = std::max<int64_t>(k, 1);\n  }\n  else\n  {\n    if(k <= 1)\n      *lda = std::max<int64_t>(m, 1);\n  }\n\n  if(transb_)\n  {\n    if(k <= 1)\n      *ldb = std::max<int64_t>(n, 1);\n  }\n  else\n  {\n    if(n <= 1)\n      *ldb = std::max<int64_t>(k, 1);\n  }\n\n}\n\nvoid HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,\n                             float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                             float beta, half *c, long ldc, long strideC, long batchCount)\n{\n  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX)  || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )\n  {\n    THError(\"HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount \"\n            \"with the bound [val] <= %d\", INT_MAX);\n  }\n\n  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);\n\n  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n  gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n}\n\n/******\nat::Tensor strided_batched_gemm_cuda(\n    float beta,\n    at::Tensor in_result,\n    float alpha,\n    at::Tensor batch1,\n    at::Tensor batch2) {\n\n  bool 
transpose_result;\n  char transpose_batch1, transpose_batch2;\n  int64_t lda, ldb, ldc;\n  at::Tensor result, input1, input2;\n  if (in_result.stride(1) == 1)\n  {\n    transpose_result = false;\n    result = in_result;\n    ldc = result.stride(2);\n  }\n  else if (in_result.stride(2) == 1)\n  {\n    transpose_result = true;\n\n    at::Tensor swap = batch2;\n    batch2 = batch1;\n    batch1 = swap;\n\n    result = in_result;\n    ldc = result.stride(1);\n  } else { \n    AT_ASSERTM(false, \"result should be contiguous\");\n  }\n\n  if (batch1.stride(transpose_result ? 2 : 1) == 1 &&\n      batch1.stride(transpose_result ? 1 : 2) != 0) {\n    transpose_batch1 = 'n';\n    input1 = batch1;\n    lda = input1.stride(transpose_result ? 1 : 2);\n  } else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&\n             batch1.stride(transpose_result ? 2 : 1) != 0) {\n    transpose_batch1 = 't';\n    input1 = batch1;\n    lda = input1.stride(transpose_result ? 2 : 1);\n  } else {\n    AT_ASSERTM(false, \"input1 should be contiguous\");\n  }\n\n  if (batch2.stride(transpose_result ? 2 : 1) == 1 &&\n      batch2.stride(transpose_result ? 1 : 2) != 0) {\n    transpose_batch2 = 'n';\n    input2 = batch2;\n    ldb = input2.stride(transpose_result ? 1 : 2);\n  } else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&\n             batch2.stride(transpose_result ? 2 : 1) != 0) {\n    transpose_batch2 = 't';\n    input2 = batch2;\n    ldb = input2.stride(transpose_result ? 2 : 1);\n  } else {\n    AT_ASSERTM(false, \"input2 should be contiguous\");\n  }\n  int64_t num_batches = result.size(0);\n\n  HgemmStridedBatched(\n      state,\n      transpose_batch1,\n      transpose_batch2,\n      result.size(transpose_result ? 2 : 1),\n      result.size(transpose_result ? 1 : 2),\n      input1.size(transpose_result ? 
1 : 2),\n      alpha,\n      static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),\n      static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),\n      beta,\n      static_cast<half*>(result.data_ptr()), ldc, result.stride(0),\n      num_batches);\n\n  return in_result;\n}\n\n***/\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
    "content": "#include <torch/extension.h>\n\n// CUDA forward declaration\nvoid fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);\n\nvoid fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\nvoid fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\nvoid fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\n\nvoid fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\n\nvoid maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);\nvoid maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n// C++ interface\nvoid strided_check_finite(\n\t\tat::Tensor& overflow_flag,\n\t\tat::Tensor& p_copy,\n\t\tint stride,\n\t\tint clear_overflow_first\n\t ) {\n\tCHECK_INPUT(p_copy);\n\tfused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);\n}\nvoid adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, 
float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n        fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, 
at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n\n        fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {\n\tCHECK_INPUT(p_in);\n\tCHECK_INPUT(p_out);\n\tint64_t num_elem = p_in.numel();\n\tAT_ASSERTM(p_out.numel() == num_elem, \"number of elements in p_in and p_out should be equal\");\n\n\tmaybe_cast_cuda(overflow_flag, p_in, p_out);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n        m.def(\"strided_check_finite\", &strided_check_finite, \"Strided finite check.\");\n        m.def(\"adam\", &adam, \"Adam optimized CUDA implementation.\");\n        m.def(\"reversible_adam\", &reversible_adam, \"Reversible Adam optimized CUDA implementation.\");\n        m.def(\"adam_mt\", &fused_adam_cuda_mt, \"Multi tensor Adam optimized CUDA implementation.\");\n        m.def(\"maybe_adam_undo\", &maybe_adam_undo, \"Undo function for Adam optimized CUDA implementation.\");\n        m.def(\"maybe_cast\", &maybe_cast, \"Unpack byte tensor containing e5m2 floats.\");\n        m.def(\"maybe_cast_mt\", &maybe_cast_cuda_mt, \"Unpack byte tensor containing e5m2 floats.\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu",
    "content": "#include \"ATen/ATen.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ATen/cuda/detail/IndexUtils.cuh\"\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n#include <cmath>\n#include \"ATen/TensorUtils.h\"\n// #include \"ATen/Type.h\"\n#include \"ATen/AccumulateType.h\"\n#include <THC/THCGeneral.h>\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\n#include \"type_shim.h\"\n\ntypedef enum{\n    ADAM_MODE_0   =0, // eps under square root\n    ADAM_MODE_1   =1  // eps outside square root\n} adamMode_t;\n\ntemplate <typename T, typename GRAD_T>\n__global__ void adam_cuda_kernel(\n        T* __restrict__ p,\n        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n        //Assuming 2D grids and 2D blocks\n        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n        const int threadsPerBlock = blockDim.x * blockDim.y;\n        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n        const int i = (blockId * threadsPerBlock + threadIdInBlock);\n        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n        for (int j = i; j < tsize; j+=totThreads) {\n                T scaled_grad = g[j]/grad_scale;\n                m[j] = b1*m[j] 
+ (1-b1)*scaled_grad;\n                v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(v[j] + eps);\n                else // Mode 1\n                    denom = sqrtf(v[j]) + eps;\n                float update = (m[j]/denom) + (decay*p[j]);\n                p[j] = p[j] - (step_size*update);\n                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];\n        }\n}\n\ntemplate <int DEPTH, typename T, typename GRAD_T>\nstruct AdamFunctor\n{\n    __device__ __forceinline__ void operator()(\n        int chunk_size,\n        volatile int* noop_gmem,\n        TensorListMetadata<DEPTH>& tl,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        adamMode_t mode,\n        const float decay)\n    {\n        int tensor_loc = tl.block_to_tensor[blockIdx.x];\n        int chunk_idx = tl.block_to_chunk[blockIdx.x];\n        int n = tl.sizes[tensor_loc];\n\n        T* p = (T *)tl.addresses[0][tensor_loc];\n        p += chunk_idx*chunk_size;\n        T* m = (T *)tl.addresses[1][tensor_loc];\n        m += chunk_idx*chunk_size;\n        T* v = (T *)tl.addresses[2][tensor_loc];\n        v += chunk_idx*chunk_size;\n        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];\n        g += chunk_idx*chunk_size;\n        GRAD_T* p_copy = NULL;\n        if (DEPTH == 5) {\n            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];\n            p_copy += chunk_idx*chunk_size;\n        }\n\n        n -= chunk_idx*chunk_size;\n\n        T incoming_p[ILP];\n        T incoming_m[ILP];\n        T incoming_v[ILP];\n        T incoming_g[ILP];\n\n        // to make things simple, we put aligned case in a different code path\n        if(n % ILP == 0 &&\n           chunk_size % ILP == 0 &&\n           is_aligned(p) &&\n           is_aligned(m) &&\n           is_aligned(v) &&\n           
is_aligned(g) &&\n           is_aligned(p_copy))\n        {\n          for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n          {\n            // load\n            GRAD_T tmp_g[ILP];\n            load_store(incoming_p, p, 0, i_start);\n            load_store(incoming_m, m, 0, i_start);\n            load_store(incoming_v, v, 0, i_start);\n            load_store(tmp_g, g, 0, i_start);\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              incoming_g[ii] = static_cast<T>(tmp_g[ii]);\n              T scaled_grad = incoming_g[ii]/grad_scale;\n              incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n              incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n              float denom;\n              if (mode == ADAM_MODE_0)\n                denom = sqrtf(incoming_v[ii] + eps);\n              else // Mode 1\n                denom = sqrtf(incoming_v[ii]) + eps;\n              float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);\n              incoming_p[ii] = incoming_p[ii] - (step_size*update);\n              if (DEPTH == 5)  tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);\n            }\n            load_store(p, incoming_p, i_start, 0);\n            load_store(m, incoming_m, i_start, 0);\n            load_store(v, incoming_v, i_start, 0);\n            if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);\n          }\n        }\n        else\n        {\n          for(int i_start = 0;\n              i_start < n && i_start < chunk_size;\n              i_start += blockDim.x*ILP) {\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              incoming_p[ii] = 0;\n              incoming_m[ii] = 0;\n              incoming_v[ii] = 0;\n              incoming_g[ii] = 0;\n\n              int i = i_start + threadIdx.x + ii*blockDim.x;\n              if (i < n && i < chunk_size) {\n                incoming_p[ii] = p[i];\n                
incoming_m[ii] = m[i];\n                incoming_v[ii] = v[i];\n                incoming_g[ii] = static_cast<T>(g[i]);\n              }\n            }\n\n            // note for clarification to future michael:\n            // From a pure memory dependency perspective, there's likely no point unrolling\n            // the write loop, since writes just fire off once their LDGs arrive.\n            // Put another way, the STGs are dependent on the LDGs, but not on each other.\n            // There is still compute ILP benefit from unrolling the loop though.\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              int j = i_start + threadIdx.x + ii*blockDim.x;\n\n              if(j < n && j < chunk_size) {\n                T scaled_grad = incoming_g[ii]/grad_scale;\n                m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n                v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                  denom = sqrtf(v[j] + eps);\n                else // Mode 1\n                  denom = sqrtf(v[j]) + eps;\n                float update = (m[j]/denom) + (decay*incoming_p[ii]);\n                p[j] = incoming_p[ii] - (step_size*update);\n                if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];\n              }\n            }\n          }\n        }\n    }\n};\n\nvoid fused_adam_cuda(\n        at::Tensor & p,\n        at::Tensor & p_copy,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n//        using namespace at;\n\n        //Get tensor size\n        int tsize = p.numel();\n        //Determine #threads and #blocks\n        const int threadsPerBlock = 512;\n        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n    
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n        //Constants\n        float step_size = 0;\n        if (bias_correction == 1) {\n            const float bias_correction1 = 1 - std::pow(beta1, step);\n            const float bias_correction2 = 1 - std::pow(beta2, step);\n            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n        }\n        else {\n            step_size = lr;\n        }\n        cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n        if (g.scalar_type() == at::ScalarType::Half) {\n//all other values should be fp32 for half gradients\n            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n//dispatch is done on the gradient type\n            using namespace at; // prevents \"toString is undefined\" errors\n            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                        p.DATA_PTR<accscalar_t>(),\n                        p_copy.numel() ? 
p_copy.DATA_PTR<scalar_t_0>() : NULL,\n                        m.DATA_PTR<accscalar_t>(),\n                        v.DATA_PTR<accscalar_t>(),\n                        g.DATA_PTR<scalar_t_0>(),\n                        beta1,\n                        beta2,\n                        eps,\n                        grad_scale,\n                        step_size,\n                        tsize,\n                        (adamMode_t) mode,\n                        decay);\n                );\n      } else {\n            using namespace at;\n            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                        p.DATA_PTR<scalar_t_0>(),\n                        NULL, //don't output p_copy for fp32, it's wasted write\n                        m.DATA_PTR<scalar_t_0>(),\n                        v.DATA_PTR<scalar_t_0>(),\n                        g.DATA_PTR<scalar_t_0>(),\n                        beta1,\n                        beta2,\n                        eps,\n                        grad_scale,\n                        step_size,\n                        tsize,\n                        (adamMode_t) mode,\n                        decay);\n            );\n      }\n      THCudaCheck(cudaGetLastError());\n\n}\n\nvoid fused_adam_cuda_mt(\n    int chunk_size,\n    at::Tensor noop_flag,\n    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy\n    float lr,\n    float beta1,\n    float beta2,\n    float eps,\n    float grad_scale,\n    int step,\n    int mode,\n    int bias_correction,\n    float decay) {\n\n    //Constants\n    float step_size = 0;\n    if (bias_correction == 1) {\n        const float bias_correction1 = 1 - std::pow(beta1, step);\n        const float bias_correction2 = 1 - std::pow(beta2, step);\n        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n    }\n    else {\n        step_size = lr;\n   
 }\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    size_t tl_sz = tensor_lists.size();\n    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, \"expected tensor lists of size 4 or 5\");\n\n    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {\n//all other values should be fp32 for half gradients\n        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n//dispatch is done on the gradient type\n        if (tl_sz == 5) {\n            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                multi_tensor_apply<5>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<5, accscalar_t, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        } else {\n            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                multi_tensor_apply<4>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<4, accscalar_t, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        }\n    } else {\n        if (tl_sz == 5) {\n            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                multi_tensor_apply<5>(\n             
       BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        } else {\n            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                multi_tensor_apply<4>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        }\n    }\n    THCudaCheck(cudaGetLastError());\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__device__ void convert(const FROM_T vi, TO_T& vo)\n{\n    vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = vi;\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = 
static_cast<float>(vi);\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = t.as_half;\n}\n\ntemplate <typename GRAD_T>\n__global__ void strided_check_finite_cuda_kernel(\n        volatile int* noop_gmem,\n        GRAD_T* __restrict__ p_copy,\n        const size_t tsize,\n        int stride,\n        int clear_overflow_first)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;\n\n    if (clear_overflow_first) {\n        if (i == 0) {\n            *noop_gmem = 0;\n        }\n        __syncthreads();\n    }\n\n    for (int j = i; j < tsize; j+=totThreads) {\n        GRAD_T pi = p_copy[j];\n        if (!isfinite(pi)) {\n            *noop_gmem = 1;\n        }\n    }\n}\ntemplate <>\n__global__ void strided_check_finite_cuda_kernel(\n        volatile int* noop_gmem,\n        uint8_t* __restrict__ p_copy,\n        const size_t tsize,\n        int stride,\n        int clear_overflow_first)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;\n\n   
 if (clear_overflow_first) {\n        if (i == 0) {\n            *noop_gmem = 0;\n        }\n        __syncthreads();\n    }\n\n    for (int j = i; j < tsize; j+=totThreads) {\n        at::Half pi;\n        convert(p_copy[j], pi);\n        if (!isfinite(pi)) {\n            *noop_gmem = 1;\n        }\n    }\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__global__ void maybe_cast_kernel(\n        volatile int* overflow_flag,\n        const FROM_T* p_in,\n        TO_T* p_out,\n        const size_t tsize)\n{\n    if (overflow_flag && *overflow_flag != 0) return;\n\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    FROM_T pi[ILP];\n    TO_T po[ILP];\n\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            pi[ii] = 0;\n\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p_in[j];\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            convert(pi[ii], po[ii]);\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                p_out[j] = po[ii];\n            }\n        }\n    }\n}\n\ntemplate <typename T, typename GRAD_T, typename REDU_T>\n__global__ void reversible_adam_cuda_kernel(\n        T* __restrict__ p,\n        REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float 
eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    T mi[ILP];\n    T vi[ILP];\n    T pi[ILP];\n    T gi[ILP];\n\n    bool overflow = false;\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            mi[ii] = T(0);\n            vi[ii] = T(0);\n            pi[ii] = T(0);\n            gi[ii] = GRAD_T(0);\n\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p[j];\n                mi[ii] = m[j];\n                vi[ii] = v[j];\n                gi[ii] = static_cast<T>(g[j]);\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            T scaled_grad = gi[ii]/grad_scale;\n            if (isfinite(scaled_grad)) {\n                mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;\n                vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(vi[ii] + eps);\n                else // Mode 1\n                    denom = sqrtf(vi[ii]) + eps;\n                float update = (mi[ii]/denom) + (decay*pi[ii]);\n                pi[ii] = pi[ii] - (step_size*update);\n            } else {\n                overflow = true;\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                m[j] = mi[ii];\n                v[j] = 
 vi[ii];\n                p[j] = pi[ii];\n                if (p_copy != NULL) {\n                    convert(pi[ii], p_copy[j]);\n                }\n            }\n        }\n    }\n\n    if (p_copy != NULL) {\n        __syncthreads();\n        if (overflow) {\n            convert(float(INFINITY), p_copy[0]);\n        }\n    }\n}\n\ntemplate <typename T, typename GRAD_T>\n__global__ void maybe_adam_undo_cuda_kernel(\n        volatile int* overflow_flag,\n        T* __restrict__ p,\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n    // NB! Skip undo kernel when overflow flag is NOT set\n    if (overflow_flag && *overflow_flag == 0) return;\n\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    T mi[ILP];\n    T vi[ILP];\n    T pi[ILP];\n    T gi[ILP];\n\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            mi[ii] = T(0);\n            vi[ii] = T(0);\n            pi[ii] = T(0);\n            gi[ii] = GRAD_T(0);\n\n            // index must advance with ii, matching the other ILP kernels in this file\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p[j];\n                mi[ii] = m[j];\n                vi[ii] = v[j];\n                gi[ii] = static_cast<T>(g[j]);\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            T scaled_grad = gi[ii]/grad_scale;\n            if (isfinite(scaled_grad)) {\n     
           float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(vi[ii] + eps);\n                else // Mode 1\n                    denom = sqrtf(vi[ii]) + eps;\n                pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);\n                mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;\n                vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;\n                // Make sure round off errors don't create (small) negative value.\n                // This can happen if we have to revert the very first step.\n                vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            // index must advance with ii, matching the load loop above\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                m[j] = mi[ii];\n                v[j] = vi[ii];\n                p[j] = pi[ii];\n            }\n        }\n    }\n}\n\ntemplate <int DEPTH, typename FROM_T, typename TO_T>\nstruct MaybeCastFunctor\n{\n    __device__ __forceinline__ void operator()(\n        int chunk_size,\n        volatile int* overflow_flag,\n        TensorListMetadata<DEPTH>& tl)\n    {\n        if (overflow_flag && *overflow_flag != 0) return;\n\n        int tensor_loc = tl.block_to_tensor[blockIdx.x];\n        int chunk_idx = tl.block_to_chunk[blockIdx.x];\n        int n = tl.sizes[tensor_loc];\n\n        FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];\n        p_in += chunk_idx*chunk_size;\n        TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];\n        p_out += chunk_idx*chunk_size;\n\n        n -= chunk_idx*chunk_size;\n        int dim = chunk_size < n ? 
chunk_size : n;\n\n\tFROM_T pi[ILP];\n        TO_T po[ILP];\n\n        for(int j_start = 0;  j_start < dim;  j_start+=blockDim.x*ILP) {\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                pi[ii] = FROM_T(0);\n                int j = j_start + threadIdx.x + ii*blockDim.x;\n                if (j < dim) {\n                    pi[ii] = p_in[j];\n                }\n            }\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                convert(pi[ii], po[ii]);\n            }\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                int j = j_start + threadIdx.x + ii*blockDim.x;\n                if (j < dim) {\n                    p_out[j] = po[ii];\n                }\n            }\n        }\n    }\n};\n\nvoid fused_strided_check_finite(\n\tat::Tensor & overflow_flag,\n        at::Tensor & p_copy,\n        int stride,\n\tint clear_overflow_first)\n{\n\t//Get tensor size\n\tint tsize = p_copy.numel();\n\tint niter = (tsize + stride - 1) / stride;\n\n\t//Determine #threads and #blocks\n\tconst int threadsPerBlock = 512;\n\t//In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set.\n\tconst dim3 blocks(clear_overflow_first ? 
1 : (niter+threadsPerBlock-1)/threadsPerBlock);\n\tAT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), \"parameter tensor is too large to be indexed with int32\");\n\n\tcudaStream_t stream = at::cuda::getCurrentCUDAStream();\n        using namespace at; // prevents \"toString is undefined\" errors\n        DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, \"check_finite_cuda_kernel\",\n                strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.DATA_PTR<int>(),\n                    p_copy.DATA_PTR<scalar_t_0>(),\n                    tsize,\n                    stride,\n                    clear_overflow_first);\n                );\n\tTHCudaCheck(cudaGetLastError());\n}\n\nvoid fused_reversible_adam_cuda(\n        at::Tensor & p,\n        at::Tensor & p_copy,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n//      using namespace at;\n\n      //Get tensor size\n      int tsize = p.numel();\n      //Determine #threads and #blocks\n      const int threadsPerBlock = 512;\n      const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n      AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n      //Constants\n      float step_size = 0;\n      if (bias_correction == 1) {\n          const float bias_correction1 = 1 - std::pow(beta1, step);\n          const float bias_correction2 = 1 - std::pow(beta2, step);\n          step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n      }\n      else {\n          step_size = lr;\n      }\n      cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n      if (g.scalar_type() == at::ScalarType::Half) {\n          //all other values should 
be fp32 for half gradients\n          AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n          //dispatch is done on the gradient type\n          using namespace at; // prevents \"toString is undefined\" errors\n          if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {\n              DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                      using accscalar_t = at::acc_type<scalar_t_0, true>;\n                      reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                          p.DATA_PTR<accscalar_t>(),\n                          p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,\n                          m.DATA_PTR<accscalar_t>(),\n                          v.DATA_PTR<accscalar_t>(),\n                          g.DATA_PTR<scalar_t_0>(),\n                          beta1,\n                          beta2,\n                          eps,\n                          grad_scale,\n                          step_size,\n                          tsize,\n                          (adamMode_t) mode,\n                          decay);\n                      );\n          } else {\n              AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, \"expected parameter to be of byte type\");\n              DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_e5m2_kernel\",\n                      using accscalar_t = at::acc_type<scalar_t_0, true>;\n                      reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(\n                          p.DATA_PTR<accscalar_t>(),\n                          p_copy.DATA_PTR<uint8_t>(),\n                          m.DATA_PTR<accscalar_t>(),\n                          v.DATA_PTR<accscalar_t>(),\n                          g.DATA_PTR<scalar_t_0>(),\n                          beta1,\n                          
 beta2,\n                          eps,\n                          grad_scale,\n                          step_size,\n                          tsize,\n                          (adamMode_t) mode,\n                          decay);\n                      );\n          }\n      } else {\n          using namespace at;\n          DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                  reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                      p.DATA_PTR<scalar_t_0>(),\n                      NULL, //don't output p_copy for fp32, it's wasted write\n                      m.DATA_PTR<scalar_t_0>(),\n                      v.DATA_PTR<scalar_t_0>(),\n                      g.DATA_PTR<scalar_t_0>(),\n                      beta1,\n                      beta2,\n                      eps,\n                      grad_scale,\n                      step_size,\n                      tsize,\n                      (adamMode_t) mode,\n                      decay);\n                  );\n      }\n      THCudaCheck(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda(\n        at::Tensor & overflow_flag,\n        at::Tensor & p_in,\n        at::Tensor & p_out)\n{\n      //Get tensor size\n      int tsize = p_in.numel();\n      AT_ASSERTM(tsize == p_out.numel(), \"p_in.numel() must equal p_out.numel()\");\n      //Determine #threads and #blocks\n      const int threadsPerBlock = 512;\n      const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n      AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), \"parameter tensor is too large to be indexed with int32\");\n      //Constants\n      cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n      DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, \"maybe_cast_cuda\",\n              DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, \"maybe_cast_cuda\",\n                  
maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(\n                      overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,\n                      p_in.DATA_PTR<scalar_t_0>(),\n                      p_out.DATA_PTR<scalar_t_1>(),\n                      tsize); ))\n      THCudaCheck(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda_mt(\n    int chunk_size,\n    at::Tensor overflow_flag,\n    std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out\n{\n    //Constants\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    size_t tl_sz = tensor_lists.size();\n    AT_ASSERTM(tl_sz == 2, \"expected tensor lists of size 2\");\n\n    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, \"maybe_cast_cuda_mt_kernel\",\n            DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, \"maybe_cast_cuda_mt_kernel\",\n                multi_tensor_apply<2>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    overflow_flag,\n                    tensor_lists,\n                    MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))\n    THCudaCheck(cudaGetLastError());\n}\n\nvoid fused_maybe_adam_undo_cuda(\n        at::Tensor & overflow_flag,\n        at::Tensor & p,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n    //Get tensor size\n    int tsize = p.numel();\n    //Determine #threads and #blocks\n    const int threadsPerBlock = 512;\n    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n    //Constants\n    float step_size = 0;\n    if (bias_correction == 1) {\n        const float bias_correction1 = 
1 - std::pow(beta1, step);\n        const float bias_correction2 = 1 - std::pow(beta2, step);\n        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n    }\n    else {\n        step_size = lr;\n    }\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    if (g.scalar_type() == at::ScalarType::Half) {\n        //all other values should be fp32 for half gradients\n        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n        //dispatch is done on the gradient type\n        using namespace at; // prevents \"toString is undefined\" errors\n        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,\n                    p.DATA_PTR<accscalar_t>(),\n                    m.DATA_PTR<accscalar_t>(),\n                    v.DATA_PTR<accscalar_t>(),\n                    g.DATA_PTR<scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    tsize,\n                    (adamMode_t) mode,\n                    decay);\n                );\n    } else {\n        using namespace at;\n        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.numel() ? 
overflow_flag.DATA_PTR<int>() : NULL,\n                    p.DATA_PTR<scalar_t_0>(),\n                    m.DATA_PTR<scalar_t_0>(),\n                    v.DATA_PTR<scalar_t_0>(),\n                    g.DATA_PTR<scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    tsize,\n                    (adamMode_t) mode,\n                    decay);\n                );\n    }\n    THCudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  const float lr,\n  const float beta1,\n  const float beta2,\n  const float epsilon,\n  const int step,\n  const int bias_correction,\n  const float weight_decay,\n  const int grad_averaging,\n  const int mode,\n  const float global_grad_norm,\n  const float max_grad_norm);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n        m.def(\"lamb\", &multi_tensor_lamb_cuda, \"Computes and applies the update for the LAMB optimizer\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntypedef enum{\n  MOMENT_MODE_0   =0, // L2 regularization mode\n  MOMENT_MODE_1   =1  // Decoupled weight decay mode\n} adamMode_t;\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::optional<bool> per_tensor_python);\n\nusing MATH_T = float;\n\ntemplate<typename T>\nstruct LAMBStage1Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<4>& tl,\n    const float beta1,\n    const float beta2,\n    const float beta3,\n    const float beta1_correction,\n    const float beta2_correction,\n    const float epsilon,\n    adamMode_t mode,\n    const float decay,\n    const float global_grad_norm,\n    const float max_global_grad_norm)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? 
global_grad_norm / max_global_grad_norm : 1.0f;\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx*chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for(int i_start = 0;\n            i_start < n && i_start < chunk_size;\n            i_start += blockDim.x*ILP)\n    {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T r_v[ILP];\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          r_g[ii] = g[i];\n          // special ?optimization? for lamb stage 1\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          }\n          else {\n            r_p[ii] = p[i];\n          }\n          r_m[ii] = m[i];\n          r_v[ii] = v[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        if (mode == MOMENT_MODE_0) {\n\t  MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n\t  // L2 on scaled grad\n          scaled_grad = scaled_grad + decay*r_p[ii];\n          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = next_m_unbiased / denom;\n        }\n        else {\n          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n          
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          g[i] = r_p[ii];\n          m[i] = r_m[ii];\n          v[i] = r_v[ii];\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate<typename T>\nstruct LAMBStage2Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<2>& tl,\n    const float* per_tensor_param_norm,\n    const float* per_tensor_update_norm,\n    const float learning_rate,\n    const float decay)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = learning_rate;\n    // apply adaptive learning rate to parameters with non-zero weight decay\n    if (decay != 0.0) \n    {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? 
learning_rate * (param_norm / update_norm) : learning_rate;\n    }\n\n    T* update = (T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    for(int i_start = 0;\n            i_start < n && i_start < chunk_size;\n            i_start += blockDim.x*ILP)\n    {\n      MATH_T r_p[ILP];\n      MATH_T r_update[ILP];\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          r_p[ii] = p[i];\n          r_update[ii] = update[i];\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          p[i] = r_p[ii];\n        }\n      }\n    }\n  }\n};\n\n\nvoid multi_tensor_lamb_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  const float lr,\n  const float beta1,\n  const float beta2,\n  const float epsilon,\n  const int step,\n  const int bias_correction,\n  const float weight_decay,\n  const int grad_averaging,\n  const int mode,\n  const float global_grad_norm,\n  const float max_grad_norm)\n{\n  using namespace at;\n  // Master weights and 32-bit momentum (which may change) are not handled here,\n  // so we assume all tensors are of the same type\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = 1 - std::pow(beta2, step);\n  }\n\n  // Handle grad averaging mode\n  float beta3 = 1.0f;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);\n  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);\n\n  // Compute per tensor param norm\n  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);\n\n  // We now modify grad in-place to store the update before computing its norm\n  // Generally this is not an issue, since people modify grad in the step() method all the time\n  // We could also use a list of empty tensors to avoid this, but I'd like to save space/CPU code\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n      multi_tensor_apply<4>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        LAMBStage1Functor<scalar_t_0>(),\n        beta1,\n        beta2,\n        beta3, // 1-beta1 or 1, depending on averaging mode\n        bias_correction1,\n        bias_correction2,\n        epsilon,\n        (adamMode_t) mode,\n        weight_decay,\n        global_grad_norm,\n        max_grad_norm); )\n\n  // Compute update norms\n  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);\n\n  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);\n\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n      multi_tensor_apply<2>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        grad_param_list,\n        LAMBStage2Functor<scalar_t_0>(),\n        std::get<1>(param_norm_tuple).DATA_PTR<float>(),\n        std::get<1>(update_norm_tuple).DATA_PTR<float>(),\n        lr,\n        weight_decay); )\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_fused_adam_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor per_tensor_eps,\n  at::Tensor per_tensor_weight_decay,\n  float lr,\n  float grad_scale,\n  int step,\n  int mode);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_fused_adam\", &multi_tensor_fused_adam_cuda,\n        \"Multi tensor Adam optimized CUDA implementation.\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <THC/THCGeneral.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n#include <cmath>\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntypedef enum{\n  ADAM_MODE_0   =0, // eps under square root\n  ADAM_MODE_1   =1  // eps outside square root\n} adamMode_t;\n\ntemplate <int DEPTH, typename T, typename GRAD_T>\nstruct DistAdamFunctor\n{\n  __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<DEPTH>& tl,\n    const float* per_tensor_beta1,\n    const float* per_tensor_beta2,\n    const int* per_tensor_bias_correction,\n    const float* per_tensor_eps,\n    const float* per_tensor_weight_decay,\n    const float lr,\n    const float grad_scale,\n    const int step,\n    adamMode_t mode)\n  {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float b1 = per_tensor_beta1[tensor_num];\n    float b2 = per_tensor_beta2[tensor_num];\n    float eps = per_tensor_eps[tensor_num];\n    float decay = per_tensor_weight_decay[tensor_num];\n\n    float beta1_correction = 1.0f, beta2_correction = 1.0f;\n    if (per_tensor_bias_correction[tensor_num] == 1) {\n      beta1_correction = 1 - std::pow(b1, step);\n      beta2_correction = 1 - std::pow(b2, 
step);\n    }\n\n    T* p = (T *)tl.addresses[0][tensor_loc];\n    p += chunk_idx*chunk_size;\n    T* m = (T *)tl.addresses[1][tensor_loc];\n    m += chunk_idx*chunk_size;\n    T* v = (T *)tl.addresses[2][tensor_loc];\n    v += chunk_idx*chunk_size;\n    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];\n    g += chunk_idx*chunk_size;\n    GRAD_T* p_copy = NULL;\n    if (DEPTH == 5) {\n      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];\n      p_copy += chunk_idx*chunk_size;\n    }\n\n    n -= chunk_idx*chunk_size;\n    \n    T incoming_p[ILP];\n    T incoming_m[ILP];\n    T incoming_v[ILP];\n    T incoming_g[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 &&\n      chunk_size % ILP == 0 &&\n      is_aligned(p) &&\n      is_aligned(m) &&\n      is_aligned(v) &&\n      is_aligned(g) &&\n      is_aligned(p_copy)) {\n      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        GRAD_T tmp_g[ILP];\n        load_store(incoming_p, p, 0, i_start);\n        load_store(incoming_m, m, 0, i_start);\n        load_store(incoming_v, v, 0, i_start);\n        load_store(tmp_g, g, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_g[ii] = static_cast<T>(tmp_g[ii]);\n          T scaled_grad = incoming_g[ii]/grad_scale;\n          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n          T next_m_unbiased = incoming_m[ii] / beta1_correction;\n\t  T next_v_unbiased = incoming_v[ii] / beta2_correction;\n\t  float denom;\n          if (mode == ADAM_MODE_0)\n            denom = sqrtf(next_v_unbiased + eps);\n          else // Mode 1\n            denom = sqrtf(next_v_unbiased) + eps;\n          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);\n          incoming_p[ii] = incoming_p[ii] - (lr * update);\n\t  if 
(DEPTH == 5)  tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);\n        }\n        load_store(p, incoming_p, i_start, 0);\n        load_store(m, incoming_m, i_start, 0);\n        load_store(v, incoming_v, i_start, 0);\n        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP) {\n\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_p[ii] = 0;\n          incoming_m[ii] = 0;\n          incoming_v[ii] = 0;\n          incoming_g[ii] = 0;\n\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if (i < n && i < chunk_size) {\n            incoming_p[ii] = p[i];\n            incoming_m[ii] = m[i];\n            incoming_v[ii] = v[i];\n            incoming_g[ii] = static_cast<T>(g[i]);\n          }\n        }\n\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int j = i_start + threadIdx.x + ii*blockDim.x;\n\n          if (j < n && j < chunk_size) {\n            T scaled_grad = incoming_g[ii]/grad_scale;\n            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n            T next_m_unbiased = m[j] / beta1_correction;\n            T next_v_unbiased = v[j] / beta2_correction;\n\t    float denom;\n            if (mode == ADAM_MODE_0)\n              denom = sqrtf(next_v_unbiased + eps);\n            else // Mode 1\n              denom = sqrtf(next_v_unbiased) + eps;\n            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);\n            p[j] = incoming_p[ii] - (lr * update);\n\t    if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_fused_adam_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,  // p, m, v, g, p_copy\n  at::Tensor per_tensor_beta1,\n  at::Tensor 
per_tensor_beta2,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor per_tensor_eps,\n  at::Tensor per_tensor_weight_decay,\n  float lr,\n  float grad_scale,\n  int step,\n  int mode)\n{\n  using namespace at;\n\n  size_t tl_sz = tensor_lists.size();\n  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, \"expected tensor lists of size 4 or 5\");\n\n  if (tl_sz == 5) {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"dist_adam_cuda_kernel\",  // g\n      using accscalar_t = at::acc_type<scalar_t_0, true>;\n      multi_tensor_apply<5>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        DistAdamFunctor<5, accscalar_t, scalar_t_0>(),\n        per_tensor_beta1.DATA_PTR<float>(),\n        per_tensor_beta2.DATA_PTR<float>(),\n        per_tensor_bias_correction.DATA_PTR<int>(),\n        per_tensor_eps.DATA_PTR<float>(),\n        per_tensor_weight_decay.DATA_PTR<float>(),\n        lr,\n        grad_scale,\n        step,\n        (adamMode_t) mode);\n    );\n  } else {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"dist_adam_cuda_kernel\",  // g\n      using accscalar_t = at::acc_type<scalar_t_0, true>;\n      multi_tensor_apply<4>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        DistAdamFunctor<4, accscalar_t, scalar_t_0>(),\n        per_tensor_beta1.DATA_PTR<float>(),\n        per_tensor_beta2.DATA_PTR<float>(),\n        per_tensor_bias_correction.DATA_PTR<int>(),\n        per_tensor_eps.DATA_PTR<float>(),\n        per_tensor_weight_decay.DATA_PTR<float>(),\n        lr,\n        grad_scale,\n        step,\n        (adamMode_t) mode);\n    );\n  }\n  THCudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_compute_update_term_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_beta3,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor step,\n  at::Tensor per_tensor_epsilon,\n  const int mode,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_scale,\n  at::Tensor global_grad_norm,\n  const float max_grad_norm);\n\nvoid multi_tensor_lamb_update_weights_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_param_norm,\n  at::Tensor per_tensor_update_norm,\n  at::Tensor update_norm_offset,\n  at::Tensor learning_rate,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_grad_norm,\n  bool use_nvlamb);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_lamb_compute_update_term\", &multi_tensor_lamb_compute_update_term_cuda,\n        \"Computes update term for LAMB optimizer\");\n  m.def(\"multi_tensor_lamb_update_weights\", &multi_tensor_lamb_update_weights_cuda,\n        \"Applies update term for LAMB optimizer\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__device__ void convert(const FROM_T vi, TO_T& vo)\n{\n    vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = vi;\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = static_cast<float>(vi);\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo)\n{\n    union T\n    {\n        at::Half 
as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = t.as_half;\n}\n\ntypedef enum{\n  MOMENT_MODE_0   =0, // L2 regularization mode\n  MOMENT_MODE_1   =1  // Decoupled weight decay mode\n} adamMode_t;\n\ntemplate<typename T, typename GRAD_T, typename MATH_T>\nstruct DistOptLAMBStage1Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<5>& tl,\n    const MATH_T* per_tensor_beta1,\n    const MATH_T* per_tensor_beta2,\n    const MATH_T* per_tensor_beta3,\n    const int* per_tensor_bias_correction,\n    const int* step,\n    const MATH_T* per_tensor_epsilon,\n    adamMode_t mode,\n    const MATH_T* per_tensor_decay,\n    const MATH_T* global_scale,\n    const MATH_T* global_grad_norm,\n    const float max_grad_norm)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1)\n        return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float combined_scale = *global_scale;\n    if (max_grad_norm > 0) {\n        combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);\n\tcombined_scale = *global_scale / std::min((float) 1.0, combined_scale);\n    }\n\n    MATH_T beta1 = per_tensor_beta1[tensor_num];\n    MATH_T beta2 = per_tensor_beta2[tensor_num];\n    MATH_T beta3 = 1 - beta1;\n    MATH_T beta1_correction, beta2_correction;\n    if (per_tensor_bias_correction[tensor_num] == 1) {\n        beta1_correction = 1 - pow(beta1, *step);\n        beta2_correction = 1 - pow(beta2, *step);\n    } else {\n        beta1_correction = (MATH_T) 1.0;\n        beta2_correction = (MATH_T) 1.0;\n    }\n    MATH_T epsilon = per_tensor_epsilon[tensor_num];\n    MATH_T decay = per_tensor_decay[tensor_num];\n\n    GRAD_T* g = 
(GRAD_T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx*chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx*chunk_size;\n\n    MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];\n    u += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    MATH_T r_g[ILP];\n    MATH_T r_p[ILP];\n    MATH_T r_m[ILP];\n    MATH_T r_v[ILP];\n    // to make things simple, we put aligned case in a different code path\n    if(n % ILP == 0 &&\n       chunk_size % ILP == 0 &&\n       is_aligned(g) &&\n       is_aligned(p) &&\n       is_aligned(m) &&\n       is_aligned(v))\n    {\n      GRAD_T l_g[ILP];\n      T l_p[ILP];\n      T l_m[ILP];\n      T l_v[ILP];\n      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n      {\n        // load\n        load_store(l_g, g, 0, i_start);\n        if (decay != 0)\n          load_store(l_p, p, 0, i_start);\n        load_store(l_m, m, 0, i_start);\n        load_store(l_v, v, 0, i_start);\n        // unpack\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          r_g[ii] = l_g[ii];\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          }\n          else {\n            r_p[ii] = l_p[ii];\n          }\n          r_m[ii] = l_m[ii];\n          r_v[ii] = l_v[ii];\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay*r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = 
r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          }\n          else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          l_m[ii] = r_m[ii];\n          l_v[ii] = r_v[ii];\n        }\n        // store\n        load_store(u, r_p, i_start, 0);\n        load_store(m, l_m, i_start, 0);\n        load_store(v, l_v, i_start, 0);\n      }\n    }\n    else\n    {\n      // see note in multi_tensor_scale_kernel.cu\n      for(int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP)\n      {\n        MATH_T r_g[ILP];\n        MATH_T r_p[ILP];\n        MATH_T r_m[ILP];\n        MATH_T r_v[ILP];\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            r_g[ii] = g[i];\n            // special ?optimization? 
for lamb stage 1\n            if (decay == 0) {\n              r_p[ii] = MATH_T(0);\n            }\n            else {\n              r_p[ii] = p[i];\n            }\n            r_m[ii] = m[i];\n            r_v[ii] = v[i];\n          } else {\n            r_g[ii] = MATH_T(0);\n            r_p[ii] = MATH_T(0);\n            r_m[ii] = MATH_T(0);\n            r_v[ii] = MATH_T(0);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay*r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          }\n          else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            u[i] = r_p[ii];\n            m[i] = r_m[ii];\n            v[i] = r_v[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter 
value.\ntemplate<typename T, typename GRAD_T, typename MATH_T>\nstruct DistOptLAMBStage2Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<3>& tl,\n    const MATH_T* per_tensor_param_norm,\n    const MATH_T* per_tensor_update_norm,\n    const long* update_norm_offset,\n    const MATH_T* learning_rate,\n    const MATH_T* per_tensor_decay,\n    const MATH_T* global_grad_norm,\n    bool use_nvlamb)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1)\n        return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T decay = per_tensor_decay[tensor_num];\n\n    MATH_T ratio = *learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != (MATH_T) 0.0))\n    {\n      MATH_T param_norm = per_tensor_param_norm[tensor_num];\n      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];\n      ratio = (update_norm != 0.0 && param_norm != 0.0) ? 
(*learning_rate) * (param_norm / update_norm) : (*learning_rate);\n    }\n\n    MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];\n    p_copy += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    // to make things simple, we put aligned case in a different code path\n    if(n % ILP == 0 &&\n       chunk_size % ILP == 0 &&\n       is_aligned(p) &&\n       is_aligned(update))\n    {\n      T r_p[ILP];\n      MATH_T r_update[ILP];\n      GRAD_T r_p_copy[ILP];\n      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n      {\n        // load\n        load_store(r_p, p, 0, i_start);\n        load_store(r_update, update, 0, i_start);\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n\t  r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);\n          convert(r_p[ii], r_p_copy[ii]);\n        }\n        load_store(p, r_p, i_start, 0);\n        load_store(p_copy, r_p_copy, i_start, 0);\n      }\n    }\n    else\n    {\n      for(int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP)\n      {\n        MATH_T r_p[ILP];\n        MATH_T r_update[ILP];\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            r_p[ii] = p[i];\n            r_update[ii] = update[i];\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            p[i] = r_p[ii];\n           
 convert(r_p[ii], p_copy[i]);\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_compute_update_term_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_beta3,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor step,\n  at::Tensor per_tensor_epsilon,\n  const int mode,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_scale,\n  at::Tensor global_grad_norm,\n  const float max_grad_norm)\n{\n  using namespace at;\n\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_1\",\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, \"lamb_stage_1\",\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, \"lamb_stage_1\",\n        multi_tensor_apply<5>(\n          BLOCK_SIZE,\n          chunk_size,\n          noop_flag,\n          tensor_lists,\n          DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n          per_tensor_beta1.DATA_PTR<scalar_t_2>(),\n          per_tensor_beta2.DATA_PTR<scalar_t_2>(),\n          per_tensor_beta3.DATA_PTR<scalar_t_2>(),\n          per_tensor_bias_correction.DATA_PTR<int>(),\n          step.DATA_PTR<int>(),\n          per_tensor_epsilon.DATA_PTR<scalar_t_2>(),\n          (adamMode_t) mode,\n          per_tensor_decay.DATA_PTR<scalar_t_2>(),\n          global_scale.DATA_PTR<scalar_t_2>(),\n\t  global_grad_norm.DATA_PTR<scalar_t_2>(),\n\t  max_grad_norm); )))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_lamb_update_weights_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_param_norm,\n  at::Tensor per_tensor_update_norm,\n  at::Tensor update_norm_offset,\n  at::Tensor learning_rate,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_grad_norm,\n  bool use_nvlamb)\n{\n  using namespace at;\n\n  
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_2\",\n    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, \"lamb_stage_2\",\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, \"lamb_stage_2\",\n        multi_tensor_apply<3>(\n          BLOCK_SIZE,\n          chunk_size,\n          noop_flag,\n          tensor_lists,\n          DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n          per_tensor_param_norm.DATA_PTR<scalar_t_2>(),\n          per_tensor_update_norm.DATA_PTR<scalar_t_2>(),\n          update_norm_offset.DATA_PTR<long>(),\n\t  learning_rate.DATA_PTR<scalar_t_2>(),\n          per_tensor_decay.DATA_PTR<scalar_t_2>(),\n\t  global_grad_norm.DATA_PTR<scalar_t_2>(),\n          use_nvlamb); )))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/transducer/transducer_joint.cpp",
    "content": "#include <torch/extension.h>\n#include <ATen/Functions.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize);\n\n\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale);\n\nstd::vector<torch::Tensor> transducer_joint_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize) {\n    CHECK_INPUT(f);\n    CHECK_INPUT(g);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(gLen);\n    if (packOutput)\n        CHECK_INPUT(batchOffset);\n    return transducer_joint_cuda_forward(\n        f, \n        g, \n        fLen, \n        gLen,\n        batchOffset,\n        packedBatch,\n        opt,\n        packOutput,\n        relu,\n        dropout,\n        dropoutProb,\n        tileSize);\n}\n\nstd::vector<torch::Tensor> transducer_joint_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale) {\n    for (auto t : in){\n        CHECK_INPUT(t);\n    }\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(gLen);\n    if (packOutput)\n     
   CHECK_INPUT(batchOffset);\n    return transducer_joint_cuda_backward(\n        in, \n        fLen, \n        gLen,\n        batchOffset,\n        maxFLen,\n        maxGLen,\n        packOutput,\n        scale);\n}\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &transducer_joint_forward, \"transducer joint forward (CUDA)\");\n  m.def(\"backward\", &transducer_joint_backward, \"transducer joint backward (CUDA)\");\n}"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
    "content": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <c10/macros/Macros.h>\n#include <THC/THC.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n#include <curand_kernel.h>\n#include \"philox.h\"\n\n// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.\n// width should be a power of 2 and should be less than warpSize.\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){\n    for (unsigned offset = width/2; offset > 0; offset /= 2){\n        x += __shfl_down_sync(0xffffffff, x, offset, width);   \n    }\n    return x;\n}\n\ninline int largestPowerOfTwo(int x){\n    int y = 1;\n    while (y <= x)\n        y <<= 1;\n    return y >> 1;\n}\n\n/*\nFigure out vectorization type for masks.\nSimilar to how PyTorch figures out acc_t here:\naten/src/ATen/AccumulateType.h \n*/\ntemplate <int V>\nstruct MaskVecType { };\n\ntemplate <> struct MaskVecType<1> { using type = uint8_t; };\ntemplate <> struct MaskVecType<2> { using type = uint16_t; };\ntemplate <> struct MaskVecType<4> { using type = uint32_t; };\n\ntemplate<int V>\nusing mvec_type = typename MaskVecType<V>::type;\n\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels.\n// For fwd, batch offset and stride are different for packing and non-packing mode.\nstruct OffsetCalFwd{\n    __device__ __forceinline__ OffsetCalFwd(\n        int64_t batch, \n        const int64_t *batchOffset, \n        int64_t maxFLen, \n        int64_t maxGLen, \n        int64_t gLen,\n        int64_t hiddenSize,\n        bool packOutput) :\n        batch(batch),\n        batchOffset(batchOffset),\n        maxFLen(maxFLen),\n        maxGLen(maxGLen),\n        gLen(gLen),\n        hiddenSize(hiddenSize),\n        packOutput(packOutput)\n        {}\n    \n    
int64_t batch;\n    const int64_t *batchOffset;\n    int64_t maxFLen;\n    int64_t maxGLen;\n    int64_t gLen;\n    int64_t hiddenSize;\n    bool packOutput;\n\n    __device__ __forceinline__ int64_t getBatchOffset(){\n        return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize \n                            : batch*maxFLen*maxGLen*hiddenSize;\n    }\n\n    __device__ __forceinline__ int64_t getStrideF(){\n        return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;\n    }\n\n    \n};\n\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels\n// For bwd, batch offset and stride are different for packing and non-packing mode.\n// The reduction is done for two input tensors. Therefore, generating two sets of offsets\n// according to bwdFasterDim can lead to a unified implementation in the actual kernel.\nstruct OffsetCalBwd{\n    __device__ __forceinline__ OffsetCalBwd(\n        int64_t batch, \n        const int64_t *batchOffset, \n        const int *fLen, \n        const int *gLen,\n        int64_t maxFLen, \n        int64_t maxGLen, \n        int64_t hiddenSize,\n        bool packOutput,\n        bool bwdFasterDim) :\n        batch(batch),\n        batchOffset(batchOffset),\n        maxFLen(maxFLen),\n        maxGLen(maxGLen),\n        fLen(fLen),\n        gLen(gLen),\n        hiddenSize(hiddenSize),\n        packOutput(packOutput),\n        bwdFasterDim(bwdFasterDim)\n        {}\n\n    int64_t batch;\n    const int64_t *batchOffset;\n    const int *fLen;\n    const int *gLen;\n    int64_t maxFLen;\n    int64_t maxGLen;\n    int64_t hiddenSize;\n    bool packOutput;\n    bool bwdFasterDim;  // whether doing bwd on the faster moving dimension\n\n    __device__ __forceinline__ int64_t getBatchOffset(){\n        return packOutput ? ((batch==0) ? 
0 : batchOffset[batch-1])*hiddenSize \n                            : batch*maxFLen*maxGLen*hiddenSize;\n    }\n\n    __device__ __forceinline__ int64_t getMaxXLen(){\n        return bwdFasterDim ? maxGLen : maxFLen;\n    }\n\n    __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){\n        return bwdFasterDim ? gLen[batch] : fLen[batch];\n    }\n\n    __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){\n        return bwdFasterDim ? fLen[batch] : gLen[batch];\n    }\n    \n    __device__ __forceinline__ int64_t getStrideX(){\n        return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);\n    }\n\n    __device__ __forceinline__ int64_t getStrideY(){\n        return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;\n    }\n};\n\n\n// Vanilla transducer joint forward kernel\n// Details of this joint function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n\n// f is a tensor of shape [batch, T, H]\n// g is a tensor of shape [batch, U, H]\n// the transducer joint does\n// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\n// The resultant tensor is of shape [batch, T, U, H]\n// Each thread block is working on one \"batch\" of data in the output tensor, [batch, t, u, :]\n\n// This joint function can optionally pack the output where the output tensor with a shape of\n// [B, T, U, H] is packed into [B_packed, H].\n// Don't-care region (t > fLen) or (u > gLen) is removed.\n// To enable packing, the starting offset for each batch needs to be specified with batchOffset.\ntemplate <typename scalar_t, class OffsetCal>\n__global__ void transducer_joint_forward(\n    const scalar_t *f,\n    const scalar_t *g,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    scalar_t *sum) {\n\n\n    const int batch = blockIdx.z;\n    const int t = 
blockIdx.y;\n    const int u = blockIdx.x;\n    const auto myFLen = fLen[batch];\n    const auto myGLen = gLen[batch];\n\n    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\n    const auto myBatchOffset = offsetCal.getBatchOffset();\n    const auto strideF = offsetCal.getStrideF();\n    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;\n    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;\n    scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;\n\n    if (t < myFLen and u < myGLen){\n        #pragma unroll\n        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){\n            if (h < hiddenSize){\n                mySum[h] = myF[h] + myG[h];\n            }\n        }\n    }\n    else if (packOutput == false and t < maxFLen and u < maxGLen){\n        // Need to write finite data to don't-care region because we instantiate the result tensor\n        // with torch::empty for performance reasons. 
Even though it is a don't-care region, the \n        // contents need to be finite, otherwise they could lead to NaN in WGRAD.\n        // In packing mode, this write is no longer necessary as we remove the don't-care region\n        // from the output.\n        // Picking -1 (over 0) here for ease of testing.\n        #pragma unroll\n        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){\n            if (h < hiddenSize){\n                mySum[h] = -1;\n            }\n        }    \n    }\n}\n\n/*\nTiled version of the joint forward kernel\nDetails of this joint function can be found in: \n[1] Sequence Transduction with Recurrent Neural Networks.\n\nf is a tensor of shape [batch, T, H]\ng is a tensor of shape [batch, U, H]\nthe transducer joint does\nsum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\nThe resultant tensor is of shape [batch, T, U, H]\nEach thread is working on a tile of shape tileF x tileG in the result tensor.\nThe input for the tile is first loaded into registers and is reused tileG and tileF times. \n\nThis joint function can optionally pack the output where the output tensor with a shape of\n[B, T, U, H] is packed into [B_packed, H].\nDon't-care region (t > fLen) or (u > gLen) is removed.\nTo enable packing, the starting offset for each batch needs to be specified with batchOffset.\n\nOptionally this joint function performs ReLU and/or dropout on the joint output, which is \ncontrolled by arguments relu and dropout, respectively. philoxArgs is the argument used for generating\npseudorandom numbers. When at least one of ReLU and dropout is activated, the joint\nfunction is a masked operation, which is controlled by the template argument masked. 
In this case, \nmasks are saved to backward.\n*/\ntemplate <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>\n__global__ void transducer_joint_tiled_forward(\n    const scalar_t *f,\n    const scalar_t *g,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    int64_t hiddenPerBlock,\n    bool packOutput,\n    bool relu, \n    bool dropout,\n    float p,\n    at::PhiloxCudaState philoxArgs,\n    scalar_t *sum,\n    uint8_t *mask) {\n\n    static_assert(U == 4, \"U has to be 4, as random numbers are generated in batch of 4\");\n\n    const int batch = blockIdx.z;\n    const int t = blockIdx.y * tileF;\n    const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\n    const int u = blockIdx.x / hiddenBlock * tileG;\n    const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;\n    const int h = threadIdx.x;\n    const auto myFLen = fLen[batch];\n    const auto myGLen = gLen[batch];\n\n    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\n    const auto myBatchOffset = offsetCal.getBatchOffset();\n    const auto strideF = offsetCal.getStrideF();\n\n    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;\n    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;\n    scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;\n    uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;\n\n    // The following code is only needed for dropout. We try to bypass them as much as possible.\n    auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) \n                            : std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));\n    uint64_t tid = masked ? 
(static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x + \n                        blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x\n                            : 0;\n    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); \n    scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;  \n    bool dropoutMask[U];\n\n    if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){    \n        // register buffers for tiled input reuse\n        scalar_t fBuffer[tileF], gBuffer[tileG];    \n        for (int i = 0; i < tileF; ++i){\n            if (t + i < myFLen)\n                fBuffer[i] = myF[i*hiddenSize + h];\n        }\n        for (int j = 0; j < tileG; ++j){\n            if (u + j < myGLen)\n                gBuffer[j] = myG[j*hiddenSize + h];\n        }\n        #pragma unroll\n        for (int i = 0; i < tileF; ++i){\n            if (t + i < myFLen){\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    int idx = i*tileG + j;\n                    if (masked and dropout and idx % U == 0){\n                        // For performance, generate 4 random numbers in one shot\n                        // auto rand4 = curand_uniform4(&state);\n                        auto rand4 = uniform4(ph());\n                        dropoutMask[0] = rand4.x < p;\n                        dropoutMask[1] = rand4.y < p;\n                        dropoutMask[2] = rand4.z < p;\n                        dropoutMask[3] = rand4.w < p;\n                    }\n\n                    if (u + j < myGLen){\n                        scalar_t out = fBuffer[i] + gBuffer[j];\n                        if (masked){\n                            // Apply ReLU here when relu is True\n                            bool localMask = relu ? (out>0) : 1;\n                            localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;\n                            out = dropout ? 
out*localMask*scale : out*localMask;\n                            myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);\n                        }\n                        mySum[i*strideF + j*hiddenSize + h] = out;\n                    }\n                    else if (packOutput == false and u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n            else if (packOutput == false and t + i < maxFLen){\n                // Again need to write finite data to don't-care region\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    if (u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n        }\n    }\n    else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){\n        // Only need to ensure finiteness in normal mode\n        #pragma unroll\n        for (int i = 0; i < tileF; ++i){\n            if (t + i < maxFLen){\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    if (u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n        }\n    }\n}\n\n/*\nBwd operation (reduction) on one input tensor. 
Since the operations performed for the two input\ntensors are exactly the same, only one kernel is needed, and the different indexing offsets\nand strides are handled by OffsetCalBwd.\n\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a \nnon-packed form.\n\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\n__device__ void transducer_joint_single_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    bool bwdFasterDim,  // whether bwd on the faster moving dimension (u)\n    float scale,\n    scalar_t *inGrad,\n    int yBlockOffset=0) {\n\n\n    const int batch = blockIdx.z;\n    // For the second input tensor, this offset needs to be subtracted because the first yBlockOffset\n    // sets of thread blocks are for the first input tensor.\n    const int x = blockIdx.y-yBlockOffset;\n    const int hOffset = blockIdx.x*C10_WARP_SIZE;\n    const int wid = threadIdx.y;\n    const int lid = threadIdx.x;\n    const int numWarp = blockDim.y;\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, \n                        bwdFasterDim);\n    const auto maxXLen = offsetCal.getMaxXLen();\n    const auto myXLen = offsetCal.getMyXLen();\n    const auto myYLen = offsetCal.getMyYLen();\n    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;\n    \n    if (x < myXLen){\n        \n        const auto myBatchOffset = offsetCal.getBatchOffset();\n        const auto strideX = offsetCal.getStrideX();\n        const auto strideY = 
offsetCal.getStrideY();\n        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;\n        const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;\n        \n        // Each warp reduces numYPerWarp \"y\" first\n        acc_t warpSum = 0;\n        auto numYPerWarp = (myYLen+numWarp-1)/numWarp;\n        #pragma unroll\n        for (int warpY = 0; warpY < numYPerWarp; ++warpY){\n            auto y = wid*numYPerWarp + warpY;\n            if (y < myYLen and (hOffset+lid) < hiddenSize)\n                if (masked)\n                    warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;\n                else    \n                    warpSum += myGrad[y*strideY + lid];\n        }\n\n        // transpose partial sum in SMEM and reduce further using warpReduce\n        smem[lid*numWarp + wid] = warpSum;\n        __syncthreads();\n        auto sum = smem[wid*C10_WARP_SIZE + lid];\n        sum = warpReduce(sum, numWarp);\n\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d d\n        // example of 4 warps (a, b, c, d) with 8 threads per warp\n        // Each warp needs 8 / 4 = 2 threads to write the results.\n        if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){\n            if (lid % numWarp == 0){\n                myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;\n            }\n        }\n    }\n    else if (wid == 0 and hOffset + lid < hiddenSize){\n        // Need to ensure the grad is zero for the don't-care region\n        myInGrad[lid] = 0;\n    }\n}\n\n/*\nActual bwd (reduction) kernel that gets launched.\nCall transducer_joint_single_backward twice on two input tensors. 
\nThe two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op \nuses the rest.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\n__global__ void transducer_joint_combined_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    float scale,\n    scalar_t *fGrad,\n    scalar_t *gGrad) {\n    if (blockIdx.y < maxFLen){\n        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            false,\n            scale,\n            fGrad);\n    }\n    else{\n        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            true,\n            scale,\n            gGrad,\n            maxFLen);\n    }  \n}\n\n/*\nVectorized version of transducer_joint_single_backward\nDoing exact same operation as transducer_joint_single_backward except the load and store are\nvectorized.\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a \nnon-packed form.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\n__device__ void 
transducer_joint_single_vec_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    bool bwdFasterDim,\n    float scale,\n    scalar_t *inGrad,\n    int yBlockOffset=0){\n\n    const int batch = blockIdx.z;\n    const int x = blockIdx.y - yBlockOffset;\n    const int hOffset = blockIdx.x*C10_WARP_SIZE*V;\n    const int wid = threadIdx.y;\n    const int lid = threadIdx.x;\n    const int numWarp = blockDim.y;\n\n    // Figure out the vectorization type for mask\n    using mvec_t = mvec_type<V>;\n\n    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, \n                        bwdFasterDim);\n    const auto maxXLen = offsetCal.getMaxXLen();\n    const auto myXLen = offsetCal.getMyXLen();\n    const auto myYLen = offsetCal.getMyYLen();\n    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    acc_t warpSum[V];\n    scalar_t inBuffer[V];\n    uint8_t maskBuffer[V];\n    scalar_t outBuffer[V];\n    auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);\n    auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);\n\n    if (x < myXLen){\n        const auto myBatchOffset = offsetCal.getBatchOffset();\n        const auto strideX = offsetCal.getStrideX();\n        const auto strideY = offsetCal.getStrideY();\n        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;\n        const uint8_t *myMask = masked ? 
mask + myBatchOffset + x*strideX + hOffset\n                                            :nullptr;\n\n        for (int i = 0; i < V; ++i)\n            warpSum[i] = 0;\n\n        // Each warp reduces numYPerWarp \"y\" first\n        auto numYPerWarp = (myYLen+numWarp-1)/numWarp;\n        for (int warpY = 0; warpY < numYPerWarp; ++warpY){\n            auto y = wid*numYPerWarp + warpY;\n            auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);\n            auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)\n                                        : nullptr;\n            auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);\n            auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);\n            if (hOffset + lid*V < hiddenSize and y < myYLen){\n                *inBufferVec = myGradVec[lid];  // vectorized load\n                if (masked){\n                    *maskBufferVec = myMaskVec[lid];\n                    #pragma unroll\n                    for (int i = 0; i < V; ++i)\n                        warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;\n                }\n                else{\n                    #pragma unroll\n                    for (int i = 0; i < V; ++i)\n                        warpSum[i] += inBuffer[i];\n                }\n            }\n        }\n        \n        // transpose partial sum in SMEM and reduce further using warpReduce\n        for (int i = 0; i < V; ++i){\n            smem[lid*numWarp + wid] = warpSum[i];\n            __syncthreads();\n            auto sum = smem[wid*C10_WARP_SIZE + lid];\n\n            if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){\n                sum = warpReduce(sum, numWarp);\n                if (lid % numWarp == 0){\n                    outBuffer[i] = sum;\n                }\n            }\n            __syncthreads();\n        }\n\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d 
d\n        // a a b b c c d d\n        // example of 4 warps (a, b, c, d) with 8 threads per warp\n        // Each warp needs 8 / 4 = 2 threads to write the results.\n        if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)\n            myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;     \n    }\n    else if (wid == 0 and hOffset + lid*V < hiddenSize){\n        // Need to ensure the grad is zero for the don't-care region\n        myInGradVec[lid] = 0;\n    }\n}\n\n/*\nVectorized version of transducer_joint_combined_backward\nCall transducer_joint_single_vec_backward twice on two input tensors. \nThe two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op \nuses the rest.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\n__global__ void transducer_joint_combined_vec_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    float scale,\n    scalar_t *fGrad,\n    scalar_t *gGrad) {\n    if (blockIdx.y < maxFLen){\n        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            false,\n            scale,\n            fGrad);\n    }\n    else{\n        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            
maxGLen,\n            hiddenSize,\n            packOutput,\n            true,\n            scale,\n            gGrad,\n            maxFLen);\n    }  \n}\n\n\n\n\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize){\n\n    \n    auto tensorOpt = f.options();\n    auto dtype = f.scalar_type();\n    const auto batchSize = f.size(0);\n    const auto maxFLen = f.size(1);\n    const auto maxGLen = g.size(1);\n    const auto hiddenSize = f.size(2);\n    bool masked = dropout or relu;\n    \n    int64_t *batchOffsetPtr = nullptr;\n    torch::Tensor sum, mask;\n    auto maskOpt = tensorOpt.dtype(torch::kUInt8);\n    if (!packOutput){\n        sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);\n        batchOffsetPtr = nullptr;\n        if (masked)\n            mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);\n    }\n    else{\n        sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);    \n        batchOffsetPtr = batchOffset.data_ptr<int64_t>();\n        if (masked)\n            mask = torch::empty({packedBatch, hiddenSize}, maskOpt);\n    }\n    uint8_t *maskPtr = masked ? 
mask.data_ptr<uint8_t>() : nullptr;\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    TORCH_CHECK(opt == 0 or opt == 1, \"Got an invalid optimization level \", opt);\n    // Simple heuristics\n    const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)\n                                        / C10_WARP_SIZE * C10_WARP_SIZE);\n    \n    if (opt == 0){\n        // vanilla kernel\n        const int threads = numThread;\n        const dim3 blocks(maxGLen, maxFLen, batchSize);\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_forward\", ([&] {\n            transducer_joint_forward<scalar_t, OffsetCalFwd>\n            <<<blocks, threads, 0, stream>>>(\n                f.data_ptr<scalar_t>(), \n                g.data_ptr<scalar_t>(), \n                fLen.data_ptr<int>(), \n                gLen.data_ptr<int>(), \n                batchOffsetPtr,\n                maxFLen,\n                maxGLen,\n                hiddenSize,\n                packOutput,\n                sum.data_ptr<scalar_t>());\n        }));  \n    }\n    if (opt == 1){\n        // tiled version. For simplicity, assume tileF == tileG, even though the kernel can \n        // support more general cases.\n        const int threads = numThread;\n        const int hiddenPerBlock = numThread;\n        const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\n        const dim3 blocks(  (maxGLen+tileSize-1)/tileSize * hiddenBlock, \n                            (maxFLen+tileSize-1)/tileSize, \n                            batchSize);\n\n        TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, \n                \"Expected tileSize to be in [1, 2, 4], but got \", tileSize);\n\n        at::PhiloxCudaState rng_engine_inputs;\n        if (masked){\n            // set up PRG when the input is masked. 
rng_engine_inputs will be used as a space filler \n            // for non-masked calls.\n            // Therefore no need to initialize.\n            c10::optional<at::Generator> gen_;\n            auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, \n                                                    at::cuda::detail::getDefaultCUDAGenerator());\n            // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, \n            // each thread processes tileF * tileG output elements. \n            int64_t counterOffset = tileSize * tileSize;\n            {\n                std::lock_guard<std::mutex> lock(gen->mutex_);\n                rng_engine_inputs = gen->philox_cuda_state(counterOffset);\n            }\n        }\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_forward\", ([&] {\n            void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, \n                            int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, \n                            at::PhiloxCudaState, scalar_t*, uint8_t*);\n            if (masked){\n                switch (tileSize){\n                    case 2:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, \n                                                                    true>;\n                        break;\n                    case 4:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, \n                                                                    true>;\n                        break;\n                }\n            }\n            else{\n                switch (tileSize){\n                    case 1:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n    
                case 2:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n                    case 4:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n                }\n            }\n            \n            kernel<<<blocks, threads, 0, stream>>>(\n                f.data_ptr<scalar_t>(),\n                g.data_ptr<scalar_t>(),\n                fLen.data_ptr<int>(),\n                gLen.data_ptr<int>(),\n                batchOffsetPtr,\n                maxFLen,\n                maxGLen,\n                hiddenSize,\n                hiddenPerBlock,\n                packOutput,\n                relu,\n                dropout,\n                1.0f - dropoutProb,\n                rng_engine_inputs,\n                sum.data_ptr<scalar_t>(),\n                maskPtr);\n        }));  \n    }\n \n    THCudaCheck(cudaGetLastError());\n    if (masked) \n        return {sum, mask};\n    else\n        return {sum};\n}\n\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale){\n\n    auto grad = in[0];\n    bool masked = (in.size() == 2);\n    uint8_t *maskPtr = masked ? 
in[1].data_ptr<uint8_t>() : nullptr;\n\n    auto tensorOpt = grad.options();\n    auto dtype = grad.scalar_type();\n    const int batchSize = fLen.size(0);\n    const int hiddenSize = grad.size(-1);\n\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;\n\n    torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);\n    torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);\n\n    int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>(); \n\n    // Number of output elements along the sequence (\"y\") dimension each thread works on\n    const int workPerThread = 32;   \n    // Since the bwd for f and g have the same thread block size, we need to use the max of the two.\n    int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);\n    // Use at least 2 warps\n    numWarp = std::max(2, numWarp);\n    // Cap at the maximum number of warps the device allows\n    numWarp = std::min(maxNumWarp, numWarp); \n\n    // Need smem for transposing the partial sum. 
The partial sum is in a matrix of the shape\n    // numWarp x warpSize\n    const int smemSize = numWarp * C10_WARP_SIZE;\n    const dim3 threads(C10_WARP_SIZE, numWarp, 1);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_cuda_backward_kernel\", ([&] {\n        auto gradPtr = grad.data_ptr<scalar_t>();\n        auto fLenPtr = fLen.data_ptr<int>();\n        auto gLenPtr = gLen.data_ptr<int>(); \n        auto fGradPtr = fGrad.data_ptr<scalar_t>();\n        auto gGradPtr = gGrad.data_ptr<scalar_t>();\n\n        // resolve the acc_t type\n        using acc_t = at::acc_type<scalar_t, true>;\n        using vec_t = uint64_t;\n\n        constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\n        constexpr int vecAlignment = std::alignment_of<vec_t>::value;\n\n        // if all input and output tensors meet the alignment requirement\n        bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0) \n                        and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0) \n                        and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);\n\n        if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){\n            // If vectorization helps and the alignment requirement is met, use the vectorized \n            // kernel. 
For simplicity, hiddenSize needs to be a multiple of vectFactor.\n            const dim3 blocks(  (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), \n                                maxFLen+maxGLen, \n                                batchSize);\n            if (masked){\n                transducer_joint_combined_vec_backward\n                    <scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>\n                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n            else{\n                transducer_joint_combined_vec_backward\n                <scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);    \n            }\n        }\n        else{\n            const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, \n                                maxFLen + maxGLen, batchSize);\n            if (masked){\n                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n     
               maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n            else{\n                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n        }\n    }));   \n\n    return {fGrad, gGrad};\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/transducer/transducer_loss.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput);\n\ntorch::Tensor transducer_loss_cuda_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput);\n\n\nstd::vector<torch::Tensor> transducer_loss_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor fLen,\n    torch::Tensor yLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput\n    ) {\n\n    CHECK_INPUT(x);\n    CHECK_INPUT(label);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(yLen);\n    if (packedInput)\n        CHECK_INPUT(batchOffset);\n    return transducer_loss_cuda_forward(\n        x, \n        label, \n        fLen, \n        yLen, \n        batchOffset,\n        maxFLen,\n        blankIdx, \n        opt,\n        packedInput);\n}\n\ntorch::Tensor transducer_loss_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor fLen,\n    torch::Tensor yLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput){\n\n    CHECK_INPUT(x);\n    CHECK_INPUT(label);\n    CHECK_INPUT(lossGrad);\n    
CHECK_INPUT(alpha);\n    CHECK_INPUT(beta);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(yLen);\n    if (packedInput)\n        CHECK_INPUT(batchOffset);\n\n    return transducer_loss_cuda_backward(\n        x,\n        lossGrad,\n        alpha,\n        beta,\n        fLen,\n        yLen,\n        label,\n        batchOffset,\n        maxFLen,\n        blankIdx,\n        opt,\n        fuseSoftmaxBackward,\n        packedInput);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &transducer_loss_forward, \"transducer loss forward (CUDA)\");\n  m.def(\"backward\", &transducer_loss_backward, \"transducer loss backward (CUDA)\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
    "content": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <vector>\n#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\ntemplate<typename scalar_t>\n__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {\n    // standard log-sum-exp trick is used here to provide better numerical stability\n    return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));\n}\n\n// Vanilla transducer loss function (i.e. forward-backward algorithm)\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n\n// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted\n// into log scale by the preceding log_softmax layer\n// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. \n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into \n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_forward(\n    const scalar_t* x,\n    const int* label,\n    const int* audLen,\n    const int* txtLen,\n    const int64_t* batchOffset,\n    int64_t dictSize,   // 64-bit indexing for data tensor\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    acc_t* alpha,\n    acc_t* beta,\n    scalar_t* loss) {\n\n    const int batch = blockIdx.y;\n    const int tid = threadIdx.x;\n    const auto myFLen = audLen[batch];\n    // Note that start of the sentence is added as 1 here\n    const auto myGLen = txtLen[batch] + 1;  \n    const auto myLabel = label + batch * (maxGLen-1);\n    
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? myGLen : maxGLen;\n    const scalar_t* myX = x + myBatchOffset * dictSize; \n    int u  = tid;\n\n    if (blockIdx.x == 0){\n        // alpha path\n        acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;\n        if (u == 0) \n            myAlpha[0] = 0;\n        __syncthreads();\n\n        for (int64_t step = 1; step < myFLen+myGLen-1; ++step){\n            // Move along the diagonal wavefront to leverage available parallelism\n            for (u = tid; u < myGLen; u += blockDim.x){\n                int64_t t = step - u;\n                if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                    // Eq(16) in [1]\n                    if (u == 0){\n                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u)\n                        myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] \n                                                    + myX[((t-1)*myStrideT) * dictSize + blankIdx];\n                    }\n                    else if (t == 0){\n                        // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1)\n                        myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];\n                    }\n                    else{\n                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)\n                        acc_t current = myAlpha[(t-1)*maxGLen + u] \n                                        + myX[((t-1)*myStrideT + u) * dictSize + blankIdx];\n                        acc_t next = myAlpha[t*maxGLen + u - 1] \n                                        + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];\n                        myAlpha[t*maxGLen + u] = logSumExp(next, current);\n                    }\n                }\n            }\n            __syncthreads();\n      
  }\n    }\n    else if (blockIdx.x == 1){\n        // beta path\n        acc_t* myBeta = beta + batch*maxFLen*maxGLen;\n        if (u == 0){\n            myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT \n                                                        + myGLen - 1) * dictSize + blankIdx];\n        }\n        __syncthreads();\n\n        for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){\n            for (u = tid; u < myGLen; u += blockDim.x){\n                int64_t t = step - u;\n                if (t >= 0 and t < myFLen and u >=0 and u < myGLen){\n                    // Eq(18) in [1]\n                    if (u == myGLen - 1){\n                        // beta(t, u) = beta(t+1, u) * null(t, u)\n                        myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] \n                                                + myX[(t*myStrideT + u) * dictSize + blankIdx];\n                    }\n                    else if (t == myFLen - 1){\n                        // beta(t, u) = beta(t, u+1) * y(t, u)\n                        myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] \n                                                + myX[(t*myStrideT + u) * dictSize + myLabel[u]];\n                    }\n                    else{\n                        // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)\n                        acc_t current = myBeta[(t+1)*maxGLen + u] \n                                        + myX[(t*myStrideT + u) * dictSize + blankIdx];\n                        acc_t next = myBeta[t*maxGLen + u + 1] \n                                        + myX[(t*myStrideT + u) * dictSize + myLabel[u]];\n                        myBeta[t*maxGLen + u] = logSumExp(next, current);\n                    }\n                }\n            }\n            __syncthreads();\n        }\n        if (tid == 0)\n            loss[batch] = -myBeta[0];   \n    }\n\n}\n\n// Transducer loss function (i.e. 
forward-backward algorithm) with batch loading optimization.\n// Compared to the vanilla version, there are two optimizations:\n// 1. Load x in batch through loop unrolling to reduce the latency.\n// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.\n// For simplicity, this kernel currently only supports U <= maxThread, which should be the common\n// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.\n\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted\n// into log scale by the preceding log_softmax layer\n// Diagonal wavefront advancing usually used in dynamic programming is leveraged here.\n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into \n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, int batchLdSize>\n__global__ void transducer_loss_batch_load_forward(\n    const scalar_t* x,\n    const int* label,\n    const int* audLen,\n    const int* txtLen,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    acc_t* alpha,\n    acc_t* beta,\n    scalar_t* loss) {\n\n    const int batch = blockIdx.y;\n    int u  = threadIdx.x;\n    const auto myFLen = audLen[batch];\n    const auto myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n    const scalar_t* myX = x + myBatchOffset * dictSize; \n    scalar_t next[batchLdSize], current[batchLdSize];\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    if (blockIdx.x == 0){\n        // alpha path\n        acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;\n        // two SMEM regions for double buffering read and write data to avoid data race\n        acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};\n\n        sharedAlpha[0][u] = 0; \n        __syncthreads();\n\n        if (u == 0)\n            myAlpha[0] = 0;\n\n        auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];\n        // register used to pass value to the next step for the same thread\n        acc_t prvStepAlpha = 0;\n        for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){\n            // Move along the diagonal wavefront to leverage available parallelism\n            // Batch loading X through loop unrolling\n            #pragma unroll\n            for (int i = 0; i < batchLdSize; ++i){\n                if (step+i<myFLen+myGLen-1){\n                    // index computing\n                    int64_t t = step + i - u;\n                    int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;\n                    int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;\n                    // main loading loop\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        if (u == 0){\n                            current[i] = myX[currentId];\n                        }\n                        else if (t == 0){\n                            next[i] = myX[nextId];\n                        }\n                        else{\n                            current[i] = myX[currentId];\n                            next[i] = myX[nextId];\n                        }\n                    }\n                }\n            }\n            // 
main computing loop\n            for (int i = 0; i < batchLdSize; ++i){\n                // swap the pointer for double buffering\n                auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];\n                auto sharedAlphaWr = sharedAlpha[(step+i)%2];\n                if (step+i<myFLen+myGLen-1){\n                    int64_t t = step + i - u;\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        // Eq(16) in [1]\n                        if (u == 0)\n                            prvStepAlpha = prvStepAlpha+current[i];\n                        else if (t == 0)\n                            prvStepAlpha = sharedAlphaRd[u-1]+next[i];\n                        else\n                            prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]\n                                            + next[i]);\n                        sharedAlphaWr[u] = prvStepAlpha;\n                        myAlpha[t*maxGLen + u] = prvStepAlpha;\n                    }\n                }\n                __syncthreads();\n            }\n        }\n    }\n    else if (blockIdx.x == 1){\n        // beta path\n        acc_t* myBeta = beta + batch*maxFLen*maxGLen;\n        // two SMEM regions for double buffering read and write data to avoid data race\n        acc_t * const sharedBeta[2] = {smem, smem + maxGLen};\n        sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];\n        __syncthreads();\n\n        auto myBetaLabel = (u == maxGLen - 1) ? 
0 : label[batch*(maxGLen-1) + u];\n        // register used to pass value to the next step for the same thread\n        acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];\n        if (u == 0)\n            myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;\n\n        for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){\n            // Move along the diagonal wavefront to leverage available parallelism\n            // Batch loading X\n            #pragma unroll\n            for (int i = 0; i < batchLdSize; ++i){\n                if (step+i<myFLen+myGLen-1){\n                    // index computing\n                    int64_t t = myFLen+myGLen - (step + i) - 2 - u;\n                    int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;\n                    int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;\n                    // main loading loop\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        if (u == myGLen - 1){\n                            current[i] = myX[currentId];\n                        }\n                        else if (t == myFLen - 1){\n                            next[i] = myX[nextId];\n                        }\n                        else{\n                            current[i] = myX[currentId];\n                            next[i] = myX[nextId];\n                        }\n                    }\n                }\n            }\n            // main computing loop\n            for (int i = 0; i < batchLdSize; ++i){\n                // swap the pointer for double buffering\n                auto sharedBetaRd = sharedBeta[(step+i-1)%2];\n                auto sharedBetaWr = sharedBeta[(step+i)%2];\n                if (step+i<myFLen+myGLen-1){\n                    int64_t t = myFLen+myGLen - (step + i) - 2 - u;\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        // Eq(18) in [1]\n    
                    if (u == myGLen - 1)\n                            prvStepBeta = prvStepBeta+current[i];\n                        else if (t == myFLen - 1)\n                            prvStepBeta = sharedBetaRd[u+1]+next[i];\n                        else\n                            prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]\n                                            + next[i]);\n                        sharedBetaWr[u] = prvStepBeta;\n                        myBeta[t*maxGLen + u] = prvStepBeta;\n                    }\n                    \n                }\n                __syncthreads();\n            }\n        }\n        if (u == 0)\n            loss[batch] = -prvStepBeta; \n    }\n\n}\n\n// Vanilla transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, \n// hence only Eq(20) in [1] is implemented in this kernel.\n\n// Each thread block works on [batch, t, :, :] of data. 
Each thread works on a specific u at a time\n// Since only gradients for the correct token and null token need to be updated, gradients at other\n// locations are initialized to 0.\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n\n    const int tid = threadIdx.x;\n    const int t = blockIdx.x;\n    const int batch = blockIdx.y;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n    auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;\n    auto myAlpha = alpha + batch*maxFLen*maxGLen;\n    auto myBeta = beta + batch*maxFLen*maxGLen;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; \n    auto myLabel = label + batch*(maxGLen-1);\n\n    int64_t u = tid;\n    while (t < myFLen and u < myGLen){\n        // Do the update\n        // loss = -ln(Pr(y*|x))\n        acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];  \n        if (u != myGLen - 1)\n            myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] \n                                                + myX[u*dictSize + myLabel[u]]);\n        if (t == myFLen - 1 and u == myGLen - 1)\n            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);\n        else if (t != myFLen - 1)\n            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] \n                                                + myX[u*dictSize + blankIdx]); \n\n        u += blockDim.x;\n    }\n}\n\n// Fused transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this kernel. \n// Each thread block works on [batch, t, u, :] of data. 
Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_fused_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n    \n    const int tid = threadIdx.x;\n    const int u = blockIdx.x;\n    const int t = blockIdx.y;\n    const int batch = blockIdx.z;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; \n\n    if (t < myFLen and u < myGLen){ \n        auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; \n        auto myAlpha = alpha + batch*maxFLen*maxGLen;\n        auto myBeta = beta + batch*maxFLen*maxGLen;\n        auto myLabel = label + batch*(maxGLen-1);\n\n        // load and store shared variables in SMEM\n        if (tid == 0){\n            commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];\n            myBetaTU = myBeta[t*maxGLen + u];\n            myBetaTUp1 = myBeta[t*maxGLen + u + 1];\n            myBetaTp1U = myBeta[(t+1)*maxGLen + u];\n            myLabelShared = myLabel[u];\n        }\n\n        __syncthreads();\n\n        for (int64_t h = tid; h < dictSize; h += blockDim.x){\n            // Do the update\n            acc_t grad = commonFactor + myX[h];  // loss = -ln(Pr(y*|x))\n            acc_t myGrad = std::exp(grad + myBetaTU);\n            if (u != myGLen - 1 and h == myLabelShared){\n                myGrad -= std::exp(grad + myBetaTUp1);\n            }\n            else if (h == blankIdx){\n                if (t == myFLen - 1 and u == myGLen - 1)\n                    myGrad -= std::exp(grad);\n                else if (t != myFLen - 1)\n                    myGrad -= std::exp(grad + myBetaTp1U);\n            }\n            myXGrad[h] = myGrad;\n        }\n    }\n    else if (!packedInput){\n        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n        for (int64_t h = tid; h < dictSize; h += blockDim.x){\n            myXGrad[h] = 0;\n        }\n    }\n}\n\n\n// Vectorized version of fused transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this 
kernel. \n// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V>\n__global__ void transducer_loss_fused_vec_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n    \n    const int tid = threadIdx.x;\n    const int u = blockIdx.x;\n    const int t = blockIdx.y;\n    const int batch = blockIdx.z;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; \n    auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; \n    auto myAlpha = alpha + batch*maxFLen*maxGLen;\n    auto myBeta = beta + batch*maxFLen*maxGLen;\n    auto myLabel = label + batch*(maxGLen-1);\n\n    // Variables for vectorization\n    scalar_t myXBuffer[V], myXGradBuffer[V];\n    auto myXVec = reinterpret_cast<vec_t const *>(myX);\n    auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);\n    auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);\n    auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);\n    if (t < myFLen and u < myGLen){ \n        // load and store shared variables in SMEM\n        if (tid == 0){\n            commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];\n            myBetaTU = myBeta[t*maxGLen + u];\n            if (t != myFLen - 1)\n                myBetaTp1U = myBeta[(t+1)*maxGLen + u];\n            if (u != myGLen - 1){\n                myBetaTUp1 = myBeta[t*maxGLen + u + 1];\n                myLabelShared = myLabel[u];\n            }\n        }\n\n        __syncthreads();\n\n        #pragma unroll\n        for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){\n            // Load myX in a vector form\n            *myXBufferVec = myXVec[h0/V];\n            // Do the update for a vector of input\n            #pragma unroll\n            for (int i = 0; i < V; ++i){\n                auto h = h0 + i;\n                acc_t grad = commonFactor + myXBuffer[i];  // loss = -ln(Pr(y*|x))\n                acc_t myGrad = std::exp(grad + myBetaTU);\n                if (u != myGLen - 1 and h == myLabelShared){\n                    myGrad -= std::exp(grad + myBetaTUp1);\n                }\n                else if (h == blankIdx){\n                    if (t == myFLen - 1 and u == myGLen - 1)\n                    
    myGrad -= std::exp(grad);\n                    else if (t != myFLen - 1)\n                        myGrad -= std::exp(grad + myBetaTp1U);\n                }\n                myXGradBuffer[i] = myGrad;\n            }\n\n            // Store myXGrad in a vector form\n            myXGradVec[h0/V] = *myXGradBufferVec;\n            \n        }\n    }\n    else if (!packedInput){\n        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n        for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){\n            myXGradVec[h0/V] = 0;\n        }\n    }\n}\n\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput){\n\n    auto scalarType = x.scalar_type();\n    auto tensorOpt = x.options();\n    const int batchSize = label.size(0);\n    const int maxGLen = label.size(1) + 1;\n    const int dictSize = x.size(-1);\n\n    TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, \n                \"Expected blank index to be in the range of 0 to \", \n                dictSize-1,\n                \", but got \", \n                blankIdx);\n    TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, \n                \"Got an invalid optimization level \", \n                opt);\n\n    // The data type of alpha and beta will be resolved at dispatch time,\n    // hence defined here and assigned later\n    torch::Tensor alpha;    \n    torch::Tensor beta;\n    torch::Tensor loss = torch::empty({batchSize}, tensorOpt);\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n    const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;\n    const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, \"transducer_loss_cuda_forward\", ([&] {\n        // resolve accumulation type\n        using acc_t = at::acc_type<scalar_t, true>;\n        auto accType = c10::CppTypeToScalarType<acc_t>::value;\n        auto accTensorOpt = tensorOpt.dtype(accType);\n        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n\n        // decide which kernel to launch based on the problem size:\n        // if the required SMEM size or number of threads exceeds the limit, fall back to the vanilla\n        // kernel.\n        const auto smemSize = 2*maxGLen*sizeof(acc_t);\n        const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 \n                                    : (opt == -1) ? 1 : opt;\n        const int threads = std::min(maxThreadPerBlock, maxGLen);\n        const dim3 blocks(2, batchSize, 1);        \n\n        if (optFallBack == 0)\n            transducer_loss_forward<<<blocks, threads, 0, stream>>>(\n                x.data_ptr<scalar_t>(), \n                label.data_ptr<int>(), \n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                batchOffsetPtr,\n                dictSize, \n                blankIdx, \n                maxFLen,\n                maxGLen,\n                packedInput,\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                loss.data_ptr<scalar_t>());\n        else if (optFallBack == 1)\n            transducer_loss_batch_load_forward<scalar_t, acc_t, 4>\n            <<<blocks, threads, smemSize, stream>>>(\n                x.data_ptr<scalar_t>(), \n                label.data_ptr<int>(), \n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                
batchOffsetPtr,\n                dictSize, \n                blankIdx, \n                maxFLen,\n                maxGLen,\n                packedInput,\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                loss.data_ptr<scalar_t>());  \n\n    }));\n    THCudaCheck(cudaGetLastError());\n\n    return {alpha, beta, loss};\n}\n\n\n\n\ntorch::Tensor transducer_loss_cuda_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput){\n\n    auto dtype = x.scalar_type();\n    torch::Tensor xGrad;\n    const int batchSize = label.size(0);\n    const int maxGLen = label.size(1) + 1;\n    const int dictSize = x.size(-1);\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n    const int warpSize = deviceProperties->warpSize;\n    const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    \n    if (fuseSoftmaxBackward){\n        // allocate empty tensors for performance, hence we need to ensure zeros are written to\n        // the don't-care regions in the kernel.\n        xGrad = torch::empty_like(x);\n\n        // Would like each thread to work on 4 hidden units\n        const int workPerThread = 4;  \n        // Don't want to have more than 128 threads per thread block\n        const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);\n        const int threads = std::min(maxThreadPerElmt, std::max(warpSize, \n                                    (dictSize+workPerThread-1)/workPerThread));\n        const dim3 blocks(maxGLen, maxFLen, batchSize);\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_loss_cuda_backward\", ([&] {\n            using vec_t = uint64_t;\n            using acc_t = at::acc_type<scalar_t, true>;\n            constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\n            constexpr int vecAlignment = std::alignment_of<vec_t>::value;\n            // if all input and output tensors meet the alignment requirement\n            bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0\n                                and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>()) \n                                        % vecAlignment == 0;\n\n            if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){\n                transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>\n                    <<<blocks, threads, 0, stream>>>(    \n                    x.data_ptr<scalar_t>(), \n                    lossGrad.data_ptr<scalar_t>(),\n                    audLen.data_ptr<int>(), \n                    txtLen.data_ptr<int>(), \n                    label.data_ptr<int>(),\n                    alpha.data_ptr<acc_t>(), \n                    
beta.data_ptr<acc_t>(),  \n                    batchOffsetPtr,\n                    dictSize, \n                    blankIdx, \n                    maxFLen,\n                    maxGLen,\n                    packedInput,\n                    xGrad.data_ptr<scalar_t>());   \n            }\n            else{\n                transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(    \n                    x.data_ptr<scalar_t>(), \n                    lossGrad.data_ptr<scalar_t>(),\n                    audLen.data_ptr<int>(), \n                    txtLen.data_ptr<int>(), \n                    label.data_ptr<int>(),\n                    alpha.data_ptr<acc_t>(), \n                    beta.data_ptr<acc_t>(),  \n                    batchOffsetPtr,\n                    dictSize, \n                    blankIdx, \n                    maxFLen,\n                    maxGLen,\n                    packedInput,\n                    xGrad.data_ptr<scalar_t>());   \n                \n            }\n        }));\n    }\n    else{\n        // for the non-fused kernel, the gradients that need to be written are very sparse, hence\n        // initialize the tensor with all zeros.\n        xGrad = torch::zeros_like(x);\n        // don't launch more threads than needed.\n        const int threads = std::min(maxThreadPerBlock, maxGLen);\n        const dim3 blocks(maxFLen, batchSize);\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_loss_cuda_backward\", ([&] {\n            using acc_t = at::acc_type<scalar_t, true>;\n            transducer_loss_backward<<<blocks, threads, 0, stream>>>(    \n                x.data_ptr<scalar_t>(), \n                lossGrad.data_ptr<scalar_t>(),\n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                label.data_ptr<int>(),\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                batchOffsetPtr, \n                dictSize, \n                blankIdx, \n    
            maxFLen,\n                maxGLen,\n                packedInput,\n                xGrad.data_ptr<scalar_t>());\n        }));\n    }\n    THCudaCheck(cudaGetLastError());\n    \n    return xGrad;\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/xentropy/interface.cpp",
    "content": "#include <torch/extension.h>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> softmax_xentropy_cuda(\n    const at::Tensor &input,\n    const at::Tensor &labels,\n    const float smoothing,\n    const bool half_to_float);\n\nat::Tensor softmax_xentropy_backward_cuda(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing);\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> softmax_xentropy_forward(\n    const at::Tensor &input,\n    const at::Tensor &labels,\n    const float smoothing,\n    const bool half_to_float) {\n    CHECK_CUDA(input);\n    CHECK_INPUT(labels);\n\n    return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing)  {\n    CHECK_CUDA(grad_loss);\n    CHECK_CUDA(logits);\n    CHECK_INPUT(max_log_sum_exp);\n    CHECK_INPUT(labels);\n\n    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"forward\", &softmax_xentropy_forward, \"Softmax cross entropy loss with label smoothing forward (CUDA)\");\n    m.def(\"backward\", &softmax_xentropy_backward, \"Softmax cross entropy loss with label smoothing backward (CUDA)\");\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/csrc/xentropy/xentropy_kernel.cu",
    "content": "/**\n * From PyTorch:\n *\n * Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n * Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n * Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n * Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n * Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n *\n * From Caffe2:\n *\n * Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n *\n * All contributions by Facebook:\n * Copyright (c) 2016 Facebook Inc.\n *\n * All contributions by Google:\n * Copyright (c) 2015 Google Inc.\n * All rights reserved.\n *\n * All contributions by Yangqing Jia:\n * Copyright (c) 2015 Yangqing Jia\n * All rights reserved.\n *\n * All contributions from Caffe:\n * Copyright(c) 2013, 2014, 2015, the respective contributors\n * All rights reserved.\n *\n * All other contributions:\n * Copyright(c) 2015, 2016 the respective contributors\n * All rights reserved.\n *\n * Caffe2 uses a copyright model similar to Caffe: each contributor holds\n * copyright over their contributions to Caffe2. The project versioning records\n * all such contribution and copyright details. If a contributor wants to further\n * mark their specific copyright on a particular contribution, they should\n * indicate their copyright solely in the commit message of the change when it is\n * committed.\n *\n * All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. 
Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *\n * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n *    and IDIAP Research Institute nor the names of its contributors may be\n *    used to endorse or promote products derived from this software without\n *    specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n * POSSIBILITY OF SUCH DAMAGE.\n */\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/NumericLimits.cuh>\n\n#include <THC/THC.h>\n#include <THC/THCGeneral.h>\n#include <THC/THCThrustAllocator.cuh>\n\n#include \"type_shim.h\"\n#include \"compat.h\"\n\n#define ALIGN_BYTES 16\n\nusing Tensor = at::Tensor;\nusing TensorList = at::TensorList;\nusing ScalarType = at::ScalarType;\nusing at::acc_type;\n\ntemplate<typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxForwardEpilogue {\n  __device__ 
__forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)\n    : logsum(max_input + std::log(sum)) {}\n\n  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)\n    : logsum(max_log_sum_exp) {}\n\n  __device__ __forceinline__ OutT operator()(T input) const {\n    return static_cast<OutT>(input - logsum);\n  }\n\n  const AccumT logsum;\n};\n\ntemplate<typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxBackwardEpilogue {\n  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)\n    : sum(sum) {}\n\n  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {\n    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);\n  }\n\n  const AccumT sum;\n};\n\n\n\nconst int max_threads = 1024;\n\ninline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {\n  uint64_t block_size = 1;\n  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));\n  while (block_size < (max_block_size/2)) block_size *= 2;\n  // Launch at least a single warp - the kernel assumes that.\n  block_size = std::max(block_size, static_cast<uint64_t>(32));\n  return dim3(block_size);\n}\n\ntemplate<typename T>\nstruct Add {\n  __device__ __forceinline__ T operator()(T a, T b) const {\n    return a + b;\n  }\n};\n\ntemplate<typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const {\n    return a < b ? 
b : a;\n  }\n};\n\n\n////////////////////////////////////////////////////////////////////////////////\n// Regular kernel (fast when dim_size is large; requires inner_size == 1)\n////////////////////////////////////////////////////////////////////////////////\n\n\ntemplate <typename T, typename AccumT>\nstruct MaxFloat\n{\n  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {\n    return ::max(max, (AccumT)v);\n  }\n};\n\ntemplate<typename T, typename AccumT>\nstruct AddFloat\n{\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {\n    return sum + v;\n  }\n};\n\ntemplate<typename T, typename AccumT>\nstruct SumExpFloat\n{\n  __device__ __forceinline__ SumExpFloat(AccumT v)\n    : max_k(v) {}\n\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {\n    return sum + std::exp(v - max_k);\n  }\n\n  const AccumT max_k;\n};\n\ntemplate <template<typename> class Reduction, typename AccumT>\n__device__ __forceinline__ AccumT\nblockReduce(AccumT* smem, AccumT val,\n            const Reduction<AccumT>& r,\n            AccumT defaultVal)\n{\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val;\n\n  __syncthreads();\n\n  AccumT warpVal = defaultVal;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal = r(warpVal, smem[lane * 32 + i]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal = defaultVal;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal = r(blockVal, smem[i]);\n    }\n    smem[0] = blockVal;\n  
}\n\n  // Sync and broadcast\n  __syncthreads();\n  return smem[0];\n}\n\ntemplate <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>\n__device__ __forceinline__ void\nblockReduce(AccumT* smem,\n            AccumT* reducVal1,\n            AccumT val1,\n            const Reduction1<AccumT>& r1,\n            AccumT defaultVal1,\n            AccumT* reducVal2,\n            AccumT val2,\n            const Reduction2<AccumT>& r2,\n            AccumT defaultVal2)\n{\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val1;\n  smem[blockDim.x + threadIdx.x] = val2;\n\n  __syncthreads();\n\n  AccumT warpVal1 = defaultVal1;\n  AccumT warpVal2 = defaultVal2;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);\n        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal1;\n      smem[lane + blockDim.x] = warpVal2;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal1 = defaultVal1;\n  AccumT blockVal2 = defaultVal2;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal1 = r1(blockVal1, smem[i]);\n      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);\n    }\n    smem[0] = blockVal1;\n    smem[blockDim.x] = blockVal2;\n  }\n\n  // Sync and broadcast\n  __syncthreads();\n  *reducVal1 = smem[0];\n  *reducVal2 = smem[blockDim.x];\n  __syncthreads();\n}\n\ntemplate <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>\n__device__ __forceinline__ 
AccumT\nilpReduce(int shift,\n          T* data,\n          int size,\n          const Reduction<T, AccumT>& r,\n          AccumT defaultVal)\n{\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;\n  AccumT threadVal = defaultVal;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    data -= shift;\n    size += shift;\n    if(threadIdx.x >= shift){\n      threadVal = r(threadVal, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal = r(threadVal, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x)\n    threadVal = r(threadVal, data[offset]);\n\n  return threadVal;\n}\n\ntemplate <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>\n__device__ __forceinline__ void\nilpReduce(int shift,\n          T* data,\n          int size,\n          AccumT* reducVal1,\n          const Reduction1<T, AccumT>& r1,\n          AccumT defaultVal1,\n          AccumT* reducVal2,\n          const Reduction2<T, AccumT>& r2,\n          AccumT defaultVal2)\n{\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;\n\n  AccumT threadVal1 = defaultVal1;\n  AccumT threadVal2 = defaultVal2;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    data -= shift;\n    size += shift;\n    if(threadIdx.x >= shift){\n      threadVal1 = r1(threadVal1, data[offset]);\n      threadVal2 = r2(threadVal2, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = 
reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal1 = r1(threadVal1, v[j]);\n      threadVal2 = r2(threadVal2, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x) {\n    threadVal1 = r1(threadVal1, data[offset]);\n    threadVal2 = r2(threadVal2, data[offset]);\n  }\n\n  *reducVal1 = threadVal1;\n  *reducVal2 = threadVal2;\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>\n__global__ void\ncunn_SoftMaxXEntropyForward(\n    accscalar_t *losses,\n    outscalar_t *max_log_sum_exp,\n    scalar_t *input,\n    int64_t *labels,\n    int64_t classes,\n    const float smoothing)\n{\n  extern __shared__ unsigned char smem[];\n  auto sdata = reinterpret_cast<accscalar_t*>(smem);\n  // forward pointers to batch[blockIdx.x]\n  // each block handles a sample in the mini-batch\n  input += blockIdx.x * classes;\n  //output += blockIdx.x * classes;\n  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);\n\n  int64_t label = labels[blockIdx.x];\n\n  // find the max and sum\n  accscalar_t threadMax, threadSum, max_k, sum_k;\n  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(\n    shift, input, classes,\n    &threadMax, MaxFloat<scalar_t, accscalar_t>(),\n    -at::numeric_limits<accscalar_t>::max(),\n    &threadSum, AddFloat<scalar_t, accscalar_t>(),\n    static_cast<accscalar_t>(0));\n\n  blockReduce<Max, Add, accscalar_t>(\n      sdata,\n      &max_k, threadMax, Max<accscalar_t>(),\n      -at::numeric_limits<accscalar_t>::max(),\n      &sum_k, threadSum, Add<accscalar_t>(),\n      static_cast<accscalar_t>(0));\n\n  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, 
accscalar_t>(max_k), static_cast<accscalar_t>(0));\n  accscalar_t sumAll = blockReduce<Add, accscalar_t>(\n      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));\n\n  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);\n\n  // calculate per element loss with label smoothing\n  // reserve max + log_sum_exp for bprop\n  if (threadIdx.x == 0) {\n    accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));\n    losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \\\n      * smoothing - log_prob * (1 - smoothing);\n    max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);\n  }\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void\napply(scalar_t *gradInput,\n      scalar_t *logits,\n      outscalar_t *max_log_sum_exp,\n      outscalar_t *gradOutput,\n      int64_t *labels,\n      const float smoothing,\n      int classes)\n{\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n  int last = classes % (ILP * blockDim.x);\n\n  for (; offset < classes - last; offset += blockDim.x * ILP) {\n    accscalar_t tmpLogits[ILP];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);\n    }\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j)\n      gradInput[offset + j * blockDim.x] = tmpGradOutput * (\n        std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(\n          (offset + j * blockDim.x == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n  }\n\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>((offset == label) ? 1 : 0) *\n        smooth_positives - smooth_negatives);\n}\n\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void\naligned_apply(int shift,\n              scalar_t *gradInput,\n              scalar_t *logits,\n              outscalar_t *max_log_sum_exp,\n              outscalar_t *gradOutput,\n              int64_t *labels,\n              const float smoothing,\n              int classes)\n{\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    logits -= shift;\n    gradInput -= shift;\n    classes += shift;\n    if(threadIdx.x >= shift){\n      gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n    }\n    classes -= blockDim.x;\n    gradInput += blockDim.x;\n    logits += blockDim.x;\n    shift -= blockDim.x;\n  }\n\n  int last = classes % (ILP * blockDim.x);\n\n  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;\n  // input\n  scalar_t v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n  // output\n  scalar_t r[ILP];\n  LoadT* result = reinterpret_cast<LoadT*>(&r);\n\n  for (; offset * ILP < (classes - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(logits)[offset];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      r[j] = tmpGradOutput * (std::exp(\n          static_cast<accscalar_t>(v[j]) - coeff) -\n          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *\n          smooth_positives - smooth_negatives);\n    }\n    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;\n  }\n\n  offset = classes - last + threadIdx.x;\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>\n__global__ void\ncunn_SoftMaxXEntropyBackward(\n    scalar_t *gradInput,\n    scalar_t *logits,\n    outscalar_t *max_log_sum_exp,\n    outscalar_t *gradOutput,\n    int64_t *labels,\n    const float smoothing,\n    int classes)\n{\n  gradInput += blockIdx.x * classes;\n  logits += blockIdx.x * classes;\n\n  // Do vectorized load/store when input/output have same alignment\n  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);\n  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);\n  if (shift == shift_){\n    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);\n  }\n  else {\n    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);\n  }\n\n}\n\ntemplate<template<typename, typename, typename> class Epilogue>\nstd::vector<Tensor> host_softmax_xentropy(\n        const Tensor & input_,\n        const Tensor & labels_,\n        const float smoothing,\n        const bool half_to_float){\n  if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,\"conversion is supported for Half type only\");\n  AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,\"Label type should be CUDA Long\");\n\n  auto input = input_.contiguous();\n  Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? 
input.options().dtype(ScalarType::Float) : input.options());\n  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));\n\n  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||\n    std::is_same<acc_type<at::Half, true>, double>::value,\n    \"accscalar_t for half should be float or double\");\n  AT_ASSERTM(input.dim() == 2, \"Currently only 2 dim input supported\");\n  AT_ASSERTM(labels_.dim() == 1, \"Labels should be 1 dimensional\");\n  AT_ASSERTM(input.size(0) == labels_.size(0), \"Input and label should have same number of examples\");\n  AT_ASSERTM(input.numel() > 0, \"Number of classes in input should not be 0\");\n\n  const int64_t dim = 1;\n  int64_t outer_size = 1;\n  int64_t dim_size = input.size(dim);\n  int64_t inner_size = 1;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  for (int64_t i = 0; i < dim; ++i)\n    outer_size *= input.size(i);\n  for (int64_t i = dim + 1; i < input.dim(); ++i)\n    inner_size *= input.size(i);\n  // This kernel spawns a block per each element in the batch.\n  // XXX: it assumes that inner_size == 1\n  TORCH_CHECK(inner_size == 1, \"Currently only inner size 1 supported\");\n\n  dim3 grid(outer_size);\n\n  using namespace at;\n  DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, \"host_softmax_xentropy\",\n    using accscalar_t = at::acc_type<scalar_t_0, true>;\n    const int ILP = sizeof(float4)/sizeof(scalar_t_0);\n    dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n    if (!half_to_float) {\n      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(\n          losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<scalar_t_0>(),\n          input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),\n          dim_size, smoothing\n      );\n    } else {\n      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n        <<<grid, block, 
2 * block.x * sizeof(accscalar_t), stream>>>(\n          losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<accscalar_t>(),\n          input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),\n          dim_size, smoothing\n      );\n    }\n  );\n\n  THCudaCheck(cudaGetLastError());\n\n  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};\n  return ret;\n}\n\ntemplate<template<typename, typename, typename> class Epilogue>\nTensor host_softmax_xentropy_backward(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits_,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing,\n    bool half_to_float) {\n  const int64_t dim = 1;\n  Tensor gI = at::empty_like(logits_);\n  if (grad_loss.numel() == 0) {\n    return gI;\n  }\n\n  auto grad = grad_loss.contiguous();\n  auto logits = logits_.contiguous();\n\n  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||\n    std::is_same<acc_type<at::Half, true>, double>::value,\n    \"accscalar_t for half should be float or double\");\n  if (grad.dim() == 0) grad = grad.view(1);\n\n  AT_ASSERTM(logits_.dim() == 2, \"Currently only 2 dim input supported\");\n  AT_ASSERTM(labels.dim() == 1, \"Labels should be 1 dimensional\");\n  AT_ASSERTM(logits_.numel() > 0, \"Number of classes in input should not be 0\");\n  AT_ASSERTM(logits_.size(0) == labels.size(0), \"Input and label should have same number of examples\");\n  AT_ASSERTM(labels.size(0) == grad.size(0), \"Label and loss should have same number of examples\");\n\n  int64_t outer_size = 1;\n  int64_t dim_size = logits.size(dim);\n  int64_t inner_size = 1;\n  for (int64_t i = 0; i < dim; ++i)\n    outer_size *= logits.size(i);\n  for (int64_t i = dim + 1; i < logits.dim(); ++i)\n    inner_size *= logits.size(i);\n  // See descriptions of kernels above.\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  TORCH_CHECK(inner_size == 1, \"Currently only inner size 1 supported\");\n\n  dim3 
grid(outer_size);\n\n  DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, \"host_softmax_xentropy_backward\",\n    using accscalar_t = acc_type<scalar_t_0, true>;\n    const int ILP = sizeof(float4)/sizeof(scalar_t_0);\n    dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n    if (!half_to_float) {\n      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n       <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n          gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),\n          max_log_sum_exp.DATA_PTR<scalar_t_0>(),\n          grad.DATA_PTR<scalar_t_0>(), labels.DATA_PTR<int64_t>(),\n          smoothing, dim_size\n      );\n    } else {\n      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n       <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n          gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),\n          max_log_sum_exp.DATA_PTR<accscalar_t>(),\n          grad.DATA_PTR<accscalar_t>(), labels.DATA_PTR<int64_t>(),\n          smoothing, dim_size\n      );\n    }\n  );\n\n  THCudaCheck(cudaGetLastError());\n  return gI;\n}\n\nstd::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){\n  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward_cuda(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing) {\n  bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();\n  if (half_to_float) {\n     AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), \"expected input and grad types to match, or input to be at::Half and grad to be at::Float\");\n  }\n  return 
host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);\n}\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/examples/multihead_attn/func_test_multihead_attn.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--seed-start', default=1, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--seed-end', default=100, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd 
Pass.')\nparser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')\n\nargs = parser.parse_args()\nassert args.seq_length % 64 == 0, \"Sequence Length should be a multiple of 64!\"\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not supported')\ntorch.cuda.set_device(0)\n\ndropout_prob = 0.1\n\nfor seed in range(args.seed_start, args.seed_end+1) :\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    ref_layer = None\n    if args.encdec_attn :\n        ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    else :\n        ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    ref_layer.cuda()\n    ref_layer.half()\n    ref_layer.reset_parameters()\n\n    ref_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    ref_inputs_kv = None\n    if args.encdec_attn :\n        ref_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    ref_grads         = torch.randn_like(ref_inputs)\n\n    ref_outputs,_ = ref_layer.forward(ref_inputs,\n                                      ref_inputs_kv,\n                                      ref_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    ref_outputs.backward(ref_grads)\n\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    tst_layer = None\n    if 
args.encdec_attn :\n        tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    else:\n        tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    tst_layer.cuda()\n    tst_layer.half()\n    tst_layer.reset_parameters()\n\n    tst_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    tst_inputs_kv = None\n    if args.encdec_attn :\n        tst_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    assert torch.equal(ref_inputs,tst_inputs), \"ERROR: Inputs are different!\"\n\n    tst_grads         = torch.randn_like(tst_inputs)\n\n    tst_outputs,_ = tst_layer.forward(tst_inputs,\n                                      tst_inputs_kv,\n                                      tst_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    tst_outputs.backward(tst_grads)\n\n    fwd_close = torch.equal(ref_outputs, tst_outputs)\n    bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)\n\n    diff_fwd = ref_outputs - tst_outputs\n    diff_cnt_fwd = diff_fwd.ne(0.0).sum()\n    diff_accum_fwd = diff_fwd.abs().sum()\n\n    diff_bwd = ref_inputs.grad - tst_inputs.grad\n    diff_cnt_bwd = diff_bwd.ne(0.0).sum()\n    diff_accum_bwd = diff_bwd.abs().sum()\n\n    print(\">>> Seed: \", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')\nparser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')\n\nargs = parser.parse_args()\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not 
supported')\ntorch.cuda.set_device(0)\n\ntorch.manual_seed(111)\nif torch.cuda.is_available():\n    torch.cuda.manual_seed_all(111)\n\nattn_layers = []\nfor idx in range(0, args.layers) :\n    if args.encdec_attn :\n        if args.ref :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))\n        else :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    else :\n        if args.native :\n            attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))\n        elif args.ref :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))\n        else :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    attn_layers[idx].cuda()\n    attn_layers[idx].half()\n    if not args.native :\n        attn_layers[idx].reset_parameters()\n\nstart_evt_fwd = []\nstart_evt_bwd = []\nstop_evt_bwd  = []\nfor recorded_trial in range(0, args.trials) :\n    start_evt_fwd.append(torch.cuda.Event(enable_timing=True))\n    start_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n    stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n\nfor sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) :\n    inputs        = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    grads         = torch.randn_like(inputs)\n   \n    for trial in range(0, args.trials + args.warmup_trials) :\n        layer_inputs  = inputs\n        evt_idx       = trial - args.warmup_trials\n    \n        if evt_idx >= 0 
:\n            start_evt_fwd[evt_idx].record()\n    \n        for lyr_idx in range(0, args.layers) :\n            if args.native :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs, \n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None)\n            else :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs,\n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None,\n                                                         is_training=True)\n            layer_inputs = outputs\n    \n        if evt_idx >= 0 :\n            start_evt_bwd[evt_idx].record()\n\n        if not args.fwd :\n            layer_inputs.backward(grads)\n    \n        if evt_idx >= 0 :\n            stop_evt_bwd[evt_idx].record()\n   \n    torch.cuda.synchronize()\n    elapsed_time_fwd = 0.0\n    elapsed_time_bwd = 0.0\n    for evt_idx in range(0, args.trials) :\n        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])\n        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])\n   \n    print(\"[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms\".format(\n        'Encdec' if args.encdec_attn else 'Self',              \\\n        'Norm&Add' if args.norm_add else '',                   
\\\n        sequences*args.seq_length,                             \\\n        sequences,                                             \\\n        args.seq_length,                                       \\\n        elapsed_time_fwd / ( args.trials * args.layers ),      \\\n        elapsed_time_bwd / ( args.trials * args.layers )))\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/fmha/__init__.py",
    "content": "from .fmha import FMHAFun\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/fmha/fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n# \n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n# \n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\n\nimport torch\nimport torch.nn.functional as F\nimport fmhalib as mha\n\nclass FMHAFun(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):\n        batch_size = cu_seqlens.numel() - 1\n        if batch_size < 4:\n            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)\n        else:\n            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)\n        ctx.save_for_backward(qkv, S_dmask)\n        ctx.cu_seqlens = cu_seqlens\n        ctx.p_dropout = p_dropout\n        ctx.max_s = max_s\n        return context\n    \n    @staticmethod\n    def backward(ctx, dout):\n        qkv, S_dmask = ctx.saved_tensors\n        batch_size = ctx.cu_seqlens.numel() - 1\n        if batch_size < 4:\n            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)\n        else:\n            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)\n\n        return dqkv, None, None, None, None, None, None\n\nclass FMHA(torch.nn.Module):\n\n    def __init__(self, config):\n\n        super(FMHA, self).__init__()\n\n        self.p_dropout = config.attention_probs_dropout_prob\n        self.h = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.d = self.hidden_size // 
self.h\n        assert self.d * self.h == self.hidden_size, \"Invalid hidden size/num_heads\"\n\n    def forward(self, qkv, cu_seqlens, max_s, is_training=True):\n\n        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)\n\n        return ctx.view(-1, self.hidden_size)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/groupbn/__init__.py",
    "content": "try:\n    import torch\n    import bnp\n    from .batch_norm import BatchNorm2d_NHWC\n    del torch\n    del bnp\n    del batch_norm\nexcept ImportError as err:\n    print(\"apex was installed without --bnp flag, contrib.groupbn is not available\")\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/groupbn/batch_norm.py",
    "content": "import torch\nimport numpy as np\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nimport bnp\n\nclass bn_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):\n        if is_train:\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.fuse_relu = fuse_relu\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res =  bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)\n            return res\n        else:\n            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)\n\n    @staticmethod\n    def backward(ctx, grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        fuse_relu = ctx.fuse_relu\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, 
epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)\n\n        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass bn_addrelu_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):\n        if is_train:\n            bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res =  bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)\n            return res\n        else:\n            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)\n\n    @staticmethod\n    def backward(ctx, grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        
bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)\n\n        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\n\n\n\nclass BatchNorm2d_NHWC(_BatchNorm):\n    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True\n    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):\n        super(BatchNorm2d_NHWC, self).__init__(num_features)\n\n        self.fuse_relu = fuse_relu\n        self.multi_stream = multi_stream\n\n        self.minibatch_mean = torch.cuda.FloatTensor(num_features)\n        self.minibatch_riv = torch.cuda.FloatTensor(num_features)\n\n        #default to distributed bn disabled\n        self.bn_group = bn_group\n        self.max_cta_per_sm = max_cta_per_sm        #used only in training fwd and bwd\n        self.cta_launch_margin = cta_launch_margin  #used only in training fwd and bwd\n        self.my_data = None\n        self.pair_data = None\n        self.pair_data2 = None\n        self.pair_data3 = None\n        self.local_rank = 0\n        self.magic = torch.IntTensor([0])\n\n        #calculate cta per sm occupancies\n        assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)\n        self.fwd_occupancy =  min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.bwd_occupancy =  min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_fwd_occupancy =  min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_bwd_occupancy =  min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)\n\n        #calculate grid dimensions based on occupancy numbers
\n        mp_count = torch.cuda.get_device_properties(None).multi_processor_count\n        self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1)\n        self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1)\n        self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1)\n        self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1)\n        self.grid_dim_y = (num_features + 63) // 64\n\n        # allocate scratch space used by implementation\n        # TODO: this scratch space should not be exposed to user code. It only needs one-time initialization and the\n        # same buffer can be reused in future iterations. It is allocated here, rather than requested from the cache\n        # allocator, to avoid unnecessary re-initialization in future iterations.\n        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)\n\n        #FIXME: turn pair handles into an array\n        if bn_group>1:\n            local_rank = torch.distributed.get_rank()\n            world_size = torch.distributed.get_world_size()\n            assert(world_size >= bn_group)\n            assert(world_size % bn_group == 0)\n\n            bn_sync_steps = 1\n            if (bn_group==4):\n                bn_sync_steps = 2\n            if (bn_group==8):\n                bn_sync_steps = 3\n\n            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))\n            self.my_data = bnp.get_data_ptr(self.ipc_buffer)\n            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`\n            self.storage = self.ipc_buffer.storage()\n            self.share_cuda = self.storage._share_cuda_()\n            internal_cuda_mem = self.share_cuda\n            # internal_cuda_mem[1]: ipc_mem_handle\n            my_handle =
torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))\n            # internal_cuda_mem[3]: offset\n            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])\n\n            handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)\n            handles_l = list(handles_all.unbind(0))\n            torch.distributed.all_gather(handles_l, my_handle)\n\n            offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)\n            offsets_l = list(offsets_all.unbind(0))\n            torch.distributed.all_gather(offsets_l, my_offset)\n\n            #whom do I actually care about? that would be local_rank XOR 1\n            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()\n            pair_offset = offsets_l[local_rank ^ 1].cpu()\n            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)\n\n            if bn_group>2:\n                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()\n                pair_offset2 = offsets_l[local_rank ^ 2].cpu()\n                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)\n\n            if bn_group>4:\n                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()\n                pair_offset3 = offsets_l[local_rank ^ 4].cpu()\n                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)\n\n            #FIXME: get magic value into C code and eliminate from here\n            self.magic = torch.IntTensor([2])\n            self.local_rank = local_rank\n\n\n    def forward(self, x, z=None):\n        if z is not None:\n            assert(self.fuse_relu==True)\n            return bn_addrelu_NHWC_impl.apply(x, z,\n                                  self.weight, self.bias,\n                                  self.running_mean, self.running_var,\n                                  
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,\n                                  self.momentum,\n                                  self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,\n                                  self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,\n                                  self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,\n                                  self.multi_stream)\n        else:\n            return bn_NHWC_impl.apply(x,\n                                  self.weight, self.bias,\n                                  self.running_mean, self.running_var,\n                                  self.minibatch_mean, self.minibatch_riv, self.ret_cta,\n                                  self.momentum,\n                                  self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,\n                                  self.fwd_occupancy, self.fwd_grid_dim_x,\n                                  self.bwd_occupancy, self.bwd_grid_dim_x,\n                                  self.multi_stream)\n\n    def __del__(self):\n        if self.bn_group>1:\n          bnp.close_remote_data(self.pair_handle)\n          if self.bn_group>2:\n              bnp.close_remote_data(self.pair_handle2)\n              if self.bn_group>4:\n                 bnp.close_remote_data(self.pair_handle3)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/layer_norm/__init__.py",
    "content": "from .layer_norm import FastLayerNorm\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/layer_norm/layer_norm.py",
    "content": "import torch\nfrom torch.nn import init\n\nimport fast_layer_norm\n\nclass FastLayerNormFN(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, gamma, beta, epsilon):\n        x = x.contiguous()\n        gamma = gamma.contiguous()\n        beta = beta.contiguous()\n        hidden_size = gamma.numel()\n        xmat = x.view((-1, hidden_size))\n        ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)\n        ctx.save_for_backward(x, gamma, mu, rsigma)\n        return ymat.view(x.shape)\n    \n    @staticmethod\n    def backward(ctx, dy):\n        #assert dy.is_contiguous()\n        dy = dy.contiguous() # this happens!\n        x, gamma, mu, rsigma = ctx.saved_tensors\n\n        hidden_size = gamma.numel()\n        xmat = x.view((-1, hidden_size))\n        dymat = dy.view(xmat.shape)\n        dxmat, dgamma, dbeta = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)\n        dx = dxmat.view(x.shape)\n        return dx, dgamma, dbeta, None\n\nclass FastLayerNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-5):\n        super(FastLayerNorm, self).__init__()\n        self.epsilon = eps\n        self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))\n        self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.ones_(self.weight)\n        init.zeros_(self.bias)\n\n    def forward(self, x):\n        return FastLayerNormFN.apply(x, self.weight, self.bias, self.epsilon)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/README.md",
    "content": "# Fast Multihead Attention \n\nThis implementation has two main features :\n* A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes.\n* The removal of all copies and transposes found in standard implementations of Multihead Attention.\n\n|                                            | Python Version | C++ Version |\n| :----------------------------------------- | :------------: | :---------: |\n| Layer Norm and Residual Add Variant        | X              | X           |\n| Includes Linear Biases                     | X              |             |\n| Reduces CPU Overheads                      |                | X           |\n| Fuses masking with Softmax                 |                | X           |\n| Removes Transposes and Copies              | X              | X           |\n| Includes Self and Encoder/Decoder Variants | X              | X           |\n\n## How to Instantiate\n\n`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n\n `impl` has two options:\n * `fast` uses C++ Version\n * `default` uses Python Version\n\n## Instructions to build on Linux\n\n```\n$ git clone https://github.com/NVIDIA/apex\n$ cd apex\n$ pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--fast_multihead_attn\" ./\n```\n## Try Performance Tests Yourself!\nPerf test script is found here!\n```\ncd contrib/examples/multihead_attn\n```\n#### Fast Multihead Attention\n```\npython perf_test_multihead_attn.py --ref\n```\n#### Fast Multihead Attention with C++ Implementation\n```\npython perf_test_multihead_attn.py\n```\n#### Compare with `torch.nn.MultiheadAttn`\n```\npython perf_test_multihead_attn.py --native\n```\n#### Test your own range!\n```\npython perf_test_multihead_attn.py --seq-length 64 
--num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5\n```\n\n## Performance Comparisons\n\n* Performance was measured with 64-token sequence lengths on an NVIDIA Titan V card.\n* Time was measured across multiple layers to simulate an in-model scenario.\n\n![Multihead Attention Forward](MHA_fwd.png)\n![Multihead Attention Backward](MHA_bwd.png)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/__init__.py",
    "content": "from .self_multihead_attn import SelfMultiheadAttn\nfrom .encdec_multihead_attn import EncdecMultiheadAttn\nfrom .mask_softmax_dropout_func import fast_mask_softmax_dropout_func\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/encdec_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .encdec_multihead_attn_func               import encdec_attn_func\nfrom .fast_encdec_multihead_attn_func          import fast_encdec_attn_func\nfrom .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm       import FusedLayerNorm\n\nif hasattr(torch._C, '_jit_set_profiling_executor') :\n    torch._C._jit_set_profiling_executor(False)\nif hasattr(torch._C, '_jit_set_profiling_mode') :\n    torch._C._jit_set_profiling_mode(False)\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass EncdecMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n\n        self.in_proj_weight_q    = Parameter(torch.Tensor(embed_dim, embed_dim))\n        self.in_proj_weight_kv   = Parameter(torch.Tensor(2*embed_dim, embed_dim))\n        self.out_proj_weight     = Parameter(torch.Tensor(embed_dim, embed_dim))\n        if self.bias:\n            assert impl != 'fast', \"ERROR! 
The Fast implementation does not support biases!\"\n            self.in_proj_bias_q  = Parameter(torch.Tensor(embed_dim))\n            self.in_proj_bias_kv = Parameter(torch.Tensor(2*embed_dim))\n            self.out_proj_bias   = Parameter(torch.Tensor(embed_dim))\n        else:\n            self.register_parameter('in_proj_bias_q', None)\n            self.register_parameter('in_proj_bias_kv', None)\n            self.in_proj_bias_q  = None\n            self.in_proj_bias_kv = None\n            self.out_proj_bias   = None\n        if self.include_norm_add:\n            if impl == 'fast':\n                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm_beta_weights  = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm               = None\n            else:\n                self.register_parameter('lyr_nrm_gamma_weights', None)\n                self.register_parameter('lyr_nrm_beta_weights', None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights  = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        if self.include_norm_add:\n            if   impl == 'fast'    : self.attn_func = fast_encdec_attn_norm_add_func\n            elif impl == 'default' : self.attn_func = encdec_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if   impl == 'fast'    : self.attn_func = fast_encdec_attn_func\n            elif impl == 'default' : self.attn_func = encdec_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        nn.init.xavier_uniform_(self.in_proj_weight_q)\n        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be\n        # initialized like a [hidden, hidden] matrix.\n        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + 
hidden)) = sqrt(1.5)\n        # therefore xavier_uniform gain should be set to sqrt(1.5).\n        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            nn.init.constant_(self.in_proj_bias_q, 0.)\n            nn.init.constant_(self.in_proj_bias_kv, 0.)\n            nn.init.constant_(self.out_proj_bias, 0.)\n        if self.include_norm_add:\n            if self.impl == 'fast' :\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):\n        \"\"\"Input shape: Time x Batch x Channel\n\n        Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `attn_mask` argument. 
Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n\n        if key_padding_mask is not None:\n            assert (attn_mask is None), \"ERROR! attn_mask and key_padding_mask must not both be defined!\"\n            mask = key_padding_mask\n        elif attn_mask is not None:\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,\n                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)\n            else:\n                lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,\n                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,\n                                         mask, self.dropout)\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n        else:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)\n            else:\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, 
self.scaling, query, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,\n                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,\n                                         mask, self.dropout)\n\n        return outputs, None\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/encdec_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass EncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,\n                input_weights_q, input_weights_kv, output_weights,\n                input_biases_q, input_biases_kv, output_biases,\n                mask, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases_q is not None])\n        heads_t        = torch.tensor([heads])\n        scale_t        = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        head_dim       = inputs_q.size(2) // heads\n\n        # Input Linear GEMM Q\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim (1024), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        if use_biases_t[0]:\n            input_lin_q_results = torch.addmm(input_biases_q,\n                                              inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                                              input_weights_q.transpose(0,1),\n                                              beta=1., alpha=1.)\n        else:\n            input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))\n        input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))\n        # Input Linear GEMM KV\n        # input1: (activations) [seql_k, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_k, seqs, embed_dim*2]\n        # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x 
embed_dim*2 ) = (seql_k*seqs x embed_dim*2)\n        if use_biases_t[0]:\n            input_lin_kv_results = torch.addmm(input_biases_kv,\n                                               inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),\n                                               input_weights_kv.transpose(0,1),\n                                               beta=1., alpha=1.)\n        else:\n            input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))\n\n        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads, head_dim)\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)\n        keys    = input_lin_kv_results[:,:,0,:]\n        values  = input_lin_kv_results[:,:,1,:]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of \n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = 
torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))\n        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert (len(mask.size()) == 2), \"Timing mask is not 2D!\"\n                assert (mask.size(0) == mask.size(1)), \"Sequence length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))\n            # Key Padding Mask\n            else:\n                batches,seql_q,seql_k = matmul1_results.size()\n                seqs = int(batches / heads)\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))\n                matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout - is not executed for inference\n        if is_training:\n            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))\n        else:\n            dropout_results = softmax_results\n            dropout_mask    = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Given that pytorch cannot currently perform autograd with an output tensor specified,\n        # this requires a backward pass specified.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] 
transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)\n        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(output_biases,\n                                  matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                                  output_weights.transpose(0,1),\n                                  beta=1., alpha=1.)\n        else:\n            outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))\n        outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(use_biases_t,                             \\\n                              heads_t,                                  \\\n                              scale_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              
input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n    \n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                   \\\n        heads_t,                                                        \\\n        scale_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        head_dim                = inputs_q.size(2) // heads_t[0]\n\n        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)\n        # Sequences and heads are combined to make the batch of 
the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads_t[0], head_dim)\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)\n        keys    = input_lin_kv_results[:,:,0,:]\n        values  = input_lin_kv_results[:,:,1,:]\n\n        # Slice out k,v from one big set of gradients entering the input linear's bprop  (should only impact meta data, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared beforehand to properly slice out query, key, and value grads.\n        input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)\n        queries_grads              = torch.empty_like(queries)\n        keys_grads                 = input_lin_kv_results_grads[:,:,0,:]\n        values_grads               = input_lin_kv_results_grads[:,:,1,:]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))\n        # Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ embed_dim, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = 
( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),\n                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seqs*heads, seql_q, head_dim]\n        # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads   = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  
[seqs*heads, seql_q, seql_k] \n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),\n                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads    = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),\n                                      out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n\n        # Input Q Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)    [embed_dim (1024), embed_dim (1024)] \n        # output:              [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        queries_grads  = queries_grads.transpose(0,1).view(inputs_q.size(0)*inputs_q.size(1), heads_t[0]*head_dim)\n        input_q_grads = torch.mm(queries_grads, input_weights_q)\n        input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n        # Input KV Linear GEMM - DGRAD\n        # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]\n        # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)] \n        # output:              [seql_k, seqs, embed_dim]\n        # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = 
(seql_k*seqs x embed_dim)\n        input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0)*inputs_kv.size(1), heads_t[0]*2*head_dim)\n        input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)\n        input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))\n        # Input Q Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)] \n        # output:               [embed_dim, embed_dim]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)\n        input_weight_q_grads = torch.mm(queries_grads.transpose(0,1), inputs_q.view(inputs_q.size(0)*inputs_q.size(1), inputs_q.size(2)))\n        # Input KV Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]\n        # input2: (activations) [seql_k*seqs, embed_dim(1024)] \n        # output:               [2*embed_dim, embed_dim]\n        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)\n        input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))\n\n        if use_biases_t[0]:\n            input_bias_grads_q = torch.sum(queries_grads, 0)\n            input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)\n        else:\n            input_bias_grads_q = None\n            input_bias_grads_kv = None\n\n        return None, None, None, None,                                            \\\n               input_q_grads, input_kv_grads,                                     \\\n               input_weight_q_grads, input_weight_kv_grads, output_weight_grads,  \\\n               input_bias_grads_q, input_bias_grads_kv, output_bias_grads,        \\\n               None, None\n\nencdec_attn_func = EncdecAttnFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py",
    "content": "import torch\nimport fast_encdec_multihead_attn\n\n\nclass FastEncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        outputs =                                                       \\\n            fast_encdec_multihead_attn.forward(                         \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n                              heads,                                    \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                  
            matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        input_q_grads,                                                  \\\n  
      input_kv_grads,                                                 \\\n        input_weight_q_grads,                                           \\\n        input_weight_kv_grads,                                          \\\n        output_weight_grads =                                           \\\n            fast_encdec_multihead_attn.backward(                        \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t[0])\n\n        return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None\n\nfast_encdec_attn_func = FastEncdecAttnFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py",
    "content": "# Copyright (c) 2017-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the LICENSE file in\n# the root directory of this source tree. An additional grant of patent rights\n# can be found in the PATENTS file in the same directory.\n\nimport torch\nimport fast_encdec_multihead_attn_norm_add\n\n\nclass FastEncdecAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        dropout_add_mask,                                               \\\n        outputs =                                                       \\\n            fast_encdec_multihead_attn_norm_add.forward(                \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n               
               heads,                                    \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                         
    \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        lyr_nrm_gamma_weights,                                          \\\n        lyr_nrm_beta_weights,                                           \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_add_mask,                                               \\\n        dropout_prob_t         = ctx.saved_tensors\n\n        input_q_grads,                                                  \\\n        input_kv_grads,                                                 \\\n        lyr_nrm_gamma_grads,                                            \\\n        lyr_nrm_beta_grads,                                             \\\n        input_weight_q_grads,                    
                       \\\n        input_weight_kv_grads,                                          \\\n        output_weight_grads    =                                        \\\n            fast_encdec_multihead_attn_norm_add.backward(               \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t[0])\n\n        return None, None, None,                                        \\\n               input_q_grads,                                           \\\n               input_kv_grads,                             
             \\\n               lyr_nrm_gamma_grads,                                     \\\n               lyr_nrm_beta_grads,                                      \\\n               input_weight_q_grads,                                    \\\n               input_weight_kv_grads,                                   \\\n               output_weight_grads,                                     \\\n               None, None\n\nfast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py",
    "content": "import torch\nimport fast_self_multihead_attn\nimport fast_self_multihead_attn_bias\nimport fast_self_multihead_attn_bias_additive_mask\n\nclass FastSelfAttnFunc(torch.autograd.Function) :\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases is not None])\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n        mask_additive_t= torch.tensor([mask_additive])\n\n        if use_biases_t[0]:\n            if not mask_additive:\n                input_lin_results,                                              \\\n                softmax_results,                                                \\\n                dropout_results,                                                \\\n                dropout_mask,                                                   \\\n                matmul2_results,                                                \\\n                outputs =                                                       \\\n                    fast_self_multihead_attn_bias.forward(                           \\\n                                      use_mask,                                 \\\n                                      use_time_mask,                            \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      
input_biases,                           \\\n                                      output_biases,                           \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n                ctx.save_for_backward(use_biases_t,                                  \\\n                              heads_t,                          \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              null_tensor,                          \\\n                              null_tensor,                          \\\n                              mask_additive_t,                          \\\n                              input_lin_results,                        \\\n                              inputs,                                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n            else:\n                input_lin_results,                                              \\\n                bmm1_results,                                                \\\n                dropout_results,                                                \\\n                dropout_mask,                                                   \\\n                matmul2_results,                                                \\\n                outputs =                                                       \\\n                    fast_self_multihead_attn_bias_additive_mask.forward(                           \\\n                                      use_mask,                                
 \\\n                                      use_time_mask,                            \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      input_biases,                           \\\n                                      output_biases,                           \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n                ctx.save_for_backward(use_biases_t,                                  \\\n                                      heads_t,                          \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      null_tensor,                          \\\n                                      bmm1_results,                          \\\n                                      pad_mask,                          \\\n                                      mask_additive_t,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t)\n\n\n        else:\n            input_lin_results,          
                                    \\\n            softmax_results,                                                \\\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            matmul2_results,                                                \\\n            outputs =                                                       \\\n                fast_self_multihead_attn.forward(                           \\\n                                  use_mask,                                 \\\n                                  use_time_mask,                            \\\n                                  is_training,                              \\\n                                  heads,                                    \\\n                                  inputs,                                   \\\n                                  input_weights,                            \\\n                                  output_weights,                           \\\n                                  pad_mask if use_mask else null_tensor,    \\\n                                  dropout_prob)\n            ctx.save_for_backward(use_biases_t,                                  \\\n                          heads_t,                          \\\n                          matmul2_results,                          \\\n                          dropout_results,                          \\\n                          softmax_results,                          \\\n                          null_tensor,                          \\\n                          null_tensor,                          \\\n                          mask_additive_t,                          \\\n                          input_lin_results,                        \\\n                          inputs,                                   \\\n                          input_weights,                            \\\n  
                        output_weights,                           \\\n                          dropout_mask,                             \\\n                          dropout_prob_t)\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                        \\\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        bmm1_results,                                                \\\n        pad_mask,                                                \\\n        mask_additive_t,                                                \\\n        input_lin_results,                                              \\\n        inputs,                                                         \\\n        input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        if use_biases_t[0]:\n            if not mask_additive_t[0]:\n                input_grads,                                                    \\\n                input_weight_grads,                                             \\\n                output_weight_grads,                                           \\\n                input_bias_grads,                                                   \\\n                output_bias_grads =                                                    \\\n                    fast_self_multihead_attn_bias.backward(                          \\\n                                      heads_t[0],                               \\\n                 
                     output_grads,                             \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      softmax_results,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t[0])\n\n            else:\n                input_grads,                                                    \\\n                input_weight_grads,                                             \\\n                output_weight_grads,                                           \\\n                input_bias_grads,                                                   \\\n                output_bias_grads =                                                    \\\n                    fast_self_multihead_attn_bias_additive_mask.backward(                          \\\n                                      heads_t[0],                               \\\n                                      output_grads,                             \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      bmm1_results,                          \\\n                                      pad_mask,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                          
         \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t[0])\n                    \n        else:\n            input_bias_grads = None                                                    \n            output_bias_grads = None\n            input_grads,                                                    \\\n            input_weight_grads,                                             \\\n            output_weight_grads =                                           \\\n                fast_self_multihead_attn.backward(                          \\\n                                  heads_t[0],                               \\\n                                  output_grads,                             \\\n                                  matmul2_results,                          \\\n                                  dropout_results,                          \\\n                                  softmax_results,                          \\\n                                  input_lin_results,                        \\\n                                  inputs,                                   \\\n                                  input_weights,                            \\\n                                  output_weights,                           \\\n                                  dropout_mask,                             \\\n                                  dropout_prob_t[0])\n        return None, None, None, input_grads, input_weight_grads, output_weight_grads,input_bias_grads, output_bias_grads, None, None, None\n\nfast_self_attn_func = FastSelfAttnFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py",
    "content": "import torch\nimport fast_self_multihead_attn_norm_add\n\n\nclass FastSelfAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        input_lin_results,                                              \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        dropout_add_mask,                                               \\\n        outputs =                                                       \\\n             fast_self_multihead_attn_norm_add.forward(                 \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n                              heads,                                    \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights,                            \\\n                              
output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_results,                                              \\\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        
lyr_nrm_invvar,                                                 \\\n        inputs,                                                         \\\n        lyr_nrm_gamma_weights,                                          \\\n        lyr_nrm_beta_weights,                                           \\\n        input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_add_mask,                                               \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        input_grads,                                                    \\\n        lyr_nrm_gamma_grads,                                            \\\n        lyr_nrm_beta_grads,                                             \\\n        input_weight_grads,                                             \\\n        output_weight_grads    =                                        \\\n            fast_self_multihead_attn_norm_add.backward(                 \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,  
                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t[0])\n\n        return None, None, None,                                        \\\n               input_grads,                                             \\\n               lyr_nrm_gamma_grads,                                     \\\n               lyr_nrm_beta_grads,                                      \\\n               input_weight_grads,                                      \\\n               output_weight_grads,                                     \\\n               None, None\n\nfast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/mask_softmax_dropout_func.py",
    "content": "import torch\nimport fast_mask_softmax_dropout\nimport fast_additive_mask_softmax_dropout\n\n\nclass MaskSoftmaxDropout(torch.autograd.Function) :\n    @staticmethod\n    def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n        use_mask_t     = torch.tensor([use_mask])\n        mask_additive_t     = torch.tensor([mask_additive])\n\n        if mask_additive:\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            softmax_results =                                                \\\n                    fast_additive_mask_softmax_dropout.forward(                           \\\n                                      use_mask,                                 \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n        else:\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            softmax_results =                                                \\\n                    fast_mask_softmax_dropout.forward(                           \\\n                                      use_mask,                                 \\\n                                      is_training,                              \\\n                                      heads,                              
      \\\n                                      inputs,                                   \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n        \n        ctx.save_for_backward(\n                              use_mask_t,                                    \\\n                              heads_t,                                 \\\n                              softmax_results,                          \\\n                              dropout_mask,                             \\\n                              pad_mask if use_mask else null_tensor,        \\\n                              mask_additive_t,        \\\n                              dropout_prob_t)\n\n        return dropout_results.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_mask_t, \\\n        heads_t,   \\\n        softmax_results,                                                \\\n        dropout_mask,                                              \\\n        pad_mask,                                                   \\\n        mask_additive_t,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        if mask_additive_t[0]:\n            input_grads =                                                    \\\n                fast_additive_mask_softmax_dropout.backward(                          \\\n                                  use_mask_t[0],                             \\\n                                  heads_t[0],                             \\\n                                  output_grads,                             \\\n                                  softmax_results,                          \\\n                                  dropout_mask,                             \\\n                                  dropout_prob_t[0])\n        else:\n            input_grads =                                        
            \\\n                fast_mask_softmax_dropout.backward(                          \\\n                                  use_mask_t[0],                             \\\n                                  heads_t[0],                             \\\n                                  output_grads,                             \\\n                                  softmax_results,                          \\\n                                  dropout_mask,                             \\\n                                  pad_mask,                             \\\n                                  dropout_prob_t[0])\n        return None, None, input_grads, None, None, None\n\nfast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/self_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .self_multihead_attn_func               import self_attn_func\nfrom .fast_self_multihead_attn_func          import fast_self_attn_func\nfrom .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm     import FusedLayerNorm\n\nif hasattr(torch._C, '_jit_set_profiling_executor') :\n    torch._C._jit_set_profiling_executor(False)\nif hasattr(torch._C, '_jit_set_profiling_mode') :\n    torch._C._jit_set_profiling_mode(False)\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass SelfMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n        self.separate_qkv_params = separate_qkv_params\n        self.mask_additive = mask_additive\n        if mask_additive:\n            assert self.include_norm_add == False, \"additive mask not supported with layer norm\"\n            assert impl == 'default' or (impl == 'fast' and bias), \"additive mask not supported for fast mode without bias\"\n        if separate_qkv_params:\n            self.q_weight  = 
Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.k_weight  = Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.v_weight  = Parameter(torch.Tensor(embed_dim, embed_dim))\n        else:\n            self.in_proj_weight  = Parameter(torch.Tensor(3*embed_dim, embed_dim))\n        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))\n        if self.bias:\n            if separate_qkv_params:\n                self.q_bias  = Parameter(torch.Tensor(embed_dim))\n                self.k_bias  = Parameter(torch.Tensor(embed_dim))\n                self.v_bias  = Parameter(torch.Tensor(embed_dim))\n            else:\n                self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))\n            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))\n        else:\n            if separate_qkv_params:\n                self.register_parameter('q_bias', None)\n                self.register_parameter('k_bias', None)\n                self.register_parameter('v_bias', None)\n                self.q_bias = None\n                self.k_bias = None\n                self.v_bias = None\n            else:\n                self.register_parameter('in_proj_bias', None)\n                self.in_proj_bias = None\n            self.register_parameter('out_proj_bias', None)\n            self.out_proj_bias = None\n        if self.include_norm_add:\n            if impl == 'fast':\n                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm_beta_weights  = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm               = None\n            else:\n                self.register_parameter('lyr_norm_gamma_weights', None)\n                self.register_parameter('lyr_norm_beta_weights', None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights  = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        
if self.include_norm_add:\n            if   impl == 'fast'    : self.attn_func = fast_self_attn_norm_add_func\n            elif impl == 'default' : self.attn_func = self_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if   impl == 'fast'    : self.attn_func = fast_self_attn_func\n            elif impl == 'default' : self.attn_func = self_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        if self.separate_qkv_params:\n            nn.init.xavier_uniform_(self.q_weight)\n            nn.init.xavier_uniform_(self.k_weight)\n            nn.init.xavier_uniform_(self.v_weight)\n        else:\n            # in_proj_weight has shape [3 * hidden, hidden] but it should be\n            # initialized like a [hidden, hidden] matrix.\n            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)\n            # therefore xavier_uniform gain should be set to sqrt(2).\n            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            if self.separate_qkv_params:\n                nn.init.constant_(self.q_bias, 0.)\n                nn.init.constant_(self.k_bias, 0.)\n                nn.init.constant_(self.v_bias, 0.)\n            else:\n                nn.init.constant_(self.in_proj_bias, 0.)\n            nn.init.constant_(self.out_proj_bias, 0.)\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):\n        \"\"\"Input shape: Time x Batch x Channel\n\n        
Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `attn_mask` argument. Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n        if self.separate_qkv_params:\n            input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()\n        else:\n            input_weights = self.in_proj_weight\n        if self.bias:\n            if self.separate_qkv_params:\n                input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)],dim=1).reshape(3*self.embed_dim).contiguous()\n            else:\n                input_bias = self.in_proj_bias\n        else:\n            input_bias = None\n        if key_padding_mask is not None:\n            assert (attn_mask is None), \"ERROR attn_mask and key_padding_mask should not be both defined!\"\n            mask = key_padding_mask\n        elif attn_mask is not None:\n            assert self.mask_additive == False, \"additive mask not supported for time mask\"\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,\n                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,\n                                         input_weights, self.out_proj_weight, mask, self.dropout)\n            else:\n
                lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,\n                                         input_weights, self.out_proj_weight,\n                                         input_bias, self.out_proj_bias,\n                                         mask, self.dropout)\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n        else:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,\n                                         input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)\n            else:\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,\n                                         input_weights, self.out_proj_weight,\n                                         input_bias, self.out_proj_bias,\n                                         mask, self.mask_additive, self.dropout)\n\n        return outputs, None\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/multihead_attn/self_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nclass SelfAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, scale, inputs,\n                input_weights, output_weights,\n                input_biases, output_biases,\n                mask, is_additive_mask, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases is not None])\n        heads_t        = torch.tensor([heads])\n        scale_t        = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        head_dim       = inputs.size(2) // heads\n\n        # Input Linear GEMM\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim*3]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)\n        if use_biases_t[0]:\n            input_lin_results = torch.addmm(input_biases,\n                                            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                                            input_weights.transpose(0,1),\n                                            beta=1., alpha=1.)\n        else:\n            input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))\n        input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))\n\n        # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results = 
input_lin_results.view(inputs.size(0), inputs.size(1)*heads, 3, head_dim)\n        queries = input_lin_results[:,:,0,:]\n        keys    = input_lin_results[:,:,1,:]\n        values  = input_lin_results[:,:,2,:]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of\n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))\n        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert (len(mask.size()) == 2), \"Timing mask is not 2D!\"\n                assert (mask.size(0) == mask.size(1)), \"Sequence length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))\n            # Key Padding Mask\n            else:\n                batches,seql_q,seql_k = matmul1_results.size()\n                seqs = int(batches / heads)\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                if is_additive_mask:\n                    matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)\n                else:\n                    mask = mask.to(torch.bool)\n                    matmul1_results = 
matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))\n                matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout - is not executed for inference\n        if is_training:\n            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))\n        else:\n            dropout_results = softmax_results\n            dropout_mask    = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Given that pytorch cannot currently perform autograd with an output tensor specified,\n        # this requires a backward pass specified.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)\n        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(output_biases,\n                                  matmul2_results.view(inputs.size(0) 
* inputs.size(1), inputs.size(2)),\n                                  output_weights.transpose(0,1),\n                                  beta=1., alpha=1.)\n        else:\n            outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))\n        outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(use_biases_t,                             \\\n                              heads_t,                                  \\\n                              scale_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              inputs,                                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                   \\\n        heads_t,                                                        \\\n        scale_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_results,                                              \\\n        inputs,                                                         \\\n        
input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        head_dim                = inputs.size(2) // heads_t[0]\n\n        # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results       = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)\n        queries                 = input_lin_results[:,:,0,:]\n        keys                    = input_lin_results[:,:,1,:]\n        values                  = input_lin_results[:,:,2,:]\n\n        # Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared beforehand to properly slice out query, key, and value grads.\n        input_lin_results_grads = torch.empty_like(input_lin_results)\n        queries_grads           = input_lin_results_grads[:,:,0,:]\n        keys_grads              = input_lin_results_grads[:,:,1,:]\n        values_grads            = input_lin_results_grads[:,:,2,:]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), 
output_grads.size(2)), output_weights)\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))\n        # Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ embed_dim, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),\n                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))\n        output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads   = torch.bmm(dropout_results.transpose(1,2), 
output_lin_grads, out=values_grads.transpose(0,1))\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] \n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),\n                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads    = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),\n                                      out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n\n
        # Input Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]\n        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)] \n        # output:              [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)\n        input_grads = torch.mm(input_lin_results_grads, input_weights)\n        input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))\n        # Input Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, 3*embed_dim(3072)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)] \n        # output:               [3*embed_dim, embed_dim]\n        # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)\n        input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))\n\n        if use_biases_t[0]:\n            input_bias_grads = torch.sum(input_lin_results_grads, 0)\n        else:\n            input_bias_grads = None\n\n        return None, None, None, None,                   \\\n               input_grads,                              \\\n               input_weight_grads, output_weight_grads,  \\\n               input_bias_grads, output_bias_grads,      \\\n               None, None\n\nself_attn_func = SelfAttnFunc.apply\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/__init__.py",
    "content": "from .fp16_optimizer import FP16_Optimizer\nfrom .fused_adam import FusedAdam\nfrom .fused_lamb import FusedLAMB\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/distributed_fused_adam.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nimport torch.distributed.distributed_c10d as c10d\n\nclass DistributedFusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n    \n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n    \n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. 
(default: False)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        step_supports_amp_scaling(boolean, optional): whether to use customized\n            gradient unscaling logic (default: True)\n        num_process_groups (integer, optional): number of process groups in\n            the app (default: 1)\n        current_process_group (object, optional): the process group to work on\n            (default: None)\n        process_group_id (integer, optional): process group id (default: 0)\n        process_group_size (integer, optional): size of process group\n            (default: 0)\n        clip_grad_norm (boolean, optional): whether to handle gradient clipping\n            (default: True)\n        model_parallel (boolean, optional): whether model parallelism is used\n            (default: False)\n\n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction=True, betas=(0.9, 0.999),\n                 eps=1e-8, eps_inside_sqrt=False,\n                 weight_decay=0., max_grad_norm=0.,\n                 amsgrad=False, flat_mt=False,\n                 overlap_reductions=True,\n                 compute_L2_grad_norm=False,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,\n                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,\n                 predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False,\n                 step_supports_amp_scaling=True,\n                 num_process_groups=1,\n                 current_process_group=None,\n                 process_group_id=0,\n                 process_group_size=0,\n                 clip_grad_norm=True,\n                 model_parallel=False):\n        global fused_adam_cuda, distributed_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n        distributed_adam_cuda = importlib.import_module(\"distributed_adam_cuda\")\n        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdam, self).__init__(params, defaults)\n\n        # Misc\n        self.eps_mode = 0 if eps_inside_sqrt else 1\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n        self._step_supports_amp_scaling = step_supports_amp_scaling\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n     
   self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._compute_L2_grad_norm = compute_L2_grad_norm\n        self._L2_grad_norm = None\n        self._flat_mt = flat_mt\n        self._init_done = False\n        self._resume_from_checkpoint = False\n        self._step = 0\n\n        # Process group related\n        self._clip_grad_norm = clip_grad_norm\n        self._model_parallel = model_parallel\n        self._num_process_groups = num_process_groups\n        self._current_process_group = current_process_group if current_process_group is not None else c10d._get_default_group()\n        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())\n        self._process_group_id = process_group_id\n        self._process_group_size = torch.cuda.device_count() if process_group_size <= 0 else process_group_size\n        self._world_size = self._process_group_size # world: the current process group\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._num_groups = self._world_size // self._group_size\n        self._global_rank = torch.distributed.get_rank()\n        self._world_rank = self._global_rank // self._num_process_groups\n        self._group_rank = self._world_rank % self._group_size\n        #print(\"world_size:\", self._world_size, \", group_size:\", self._group_size, \", num_groups:\", self._num_groups, \", global_rank:\", self._global_rank, \", world_rank:\", self._world_rank, \", group_rank:\", self._group_rank)\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n\n        # Master weight, moment, gradient buffers\n        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None\n\n  
  def _first_step_init(self):\n        p_offset = 0\n        p_i = 0\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        self._group_properties = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            beta1, beta2 = group['betas']\n            bias_correction = 1 if group['bias_correction'] else 0\n            eps = group['eps']\n            weight_decay = group['weight_decay']\n            for p in group['params']:\n                # broadcast from rank 0 of current process group\n                torch.distributed.broadcast(p, src=self._available_ranks[0], group=self._current_process_group)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                # Multiple param groups support: \n                # store one hyperparam item per parameter tensor\n                self._group_properties.append((\n                    beta1,\n                    beta2,\n                    bias_correction,\n                    eps,\n                    weight_decay\n                    ))\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n              
  # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._grads = []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._chunk_size = self._block_size // self._num_chunks\n        self._shard_size = self._chunk_size // self._group_size\n        #print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        #print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * 
self._shard_size\n        # initialize master weights, moments buffers if not loaded from checkpoint\n        if self._fp32_p is None:\n            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        self._individual_flat_grads = []\n        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):\n            self._individual_flat_grads.append(self._flat_grads[grads_info[\"param_offset\"]:grads_info[\"param_offset\"]+grads_info[\"param_grads_size\"]].view_as(p))\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            list_of_blocks = __blockify(self._flat_grads)\n            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]\n            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards\n        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = 
_flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._shard_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, 
self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        # This block does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params = []\n        self._contrib_tensor_list = []\n        self._contrib_group_properties = []\n        self._non_parallel_grads = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size\n                    flat_shard_end = flat_shard_start + self._shard_size\n                    for (p, grads_info, group_props) in zip(self._model_params, self._grads_info, self._group_properties):\n                        flat_grad_start = grads_info[\"param_offset\"]\n                        flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                        # intersect this parameter's flat-gradient span with the current shard\n                        clipped_start = max(flat_grad_start, flat_shard_start)\n                        clipped_end = min(flat_grad_end, flat_shard_end)\n                        if clipped_start < clipped_end:\n                            grad_offset = clipped_start - flat_grad_start\n                            grad_length = clipped_end - clipped_start\n                            shard_offset = clipped_start - flat_shard_start\n                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                            new_param_packed_fragment = 
self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                            self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )\n                            if shard_id == self._group_rank:\n                                # copy model parameters into master buffer\n                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                #print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                                if not self._resume_from_checkpoint:\n                                    master_param_fragment.copy_(model_param_fragment)\n                                self._contrib_group_properties.append(group_props)\n                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy\n                                if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:\n                                    self._non_parallel_grads.append(opti_state_g_fragment)\n\n        p, m, v, g, p_copy = 
list(zip(*self._contrib_tensor_list))\n        self._contrib_tensor_list = [p, m, v, g, p_copy]\n\n        math_type = self._fp32_p.dtype\n        beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))\n        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')\n        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')\n        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')\n        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')\n        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')\n\n        p_in, p_out = zip(*self._packed_flat_to_model_params)\n        self._packed_flat_to_model_params = [p_in, p_out]\n\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for i in range(self._num_process_groups):\n                # gather global ranks of all members of the current process group\n                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]\n                for j in range(self._group_size):\n                    ar_idx = [j+k*self._group_size for k in range(self._num_groups)]\n                    ar_rank = [ranks[k] for k in ar_idx]\n                    #if self._global_rank in ar_rank:\n                    #    print(\"group for all reduce, ranks:\", ar_rank)\n                    for _ in range(self._num_ar_pg):\n                        grp = torch.distributed.new_group(ranks=ar_rank)\n                        if self._global_rank in ar_rank:\n                            self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            for ar_pg in self._ar_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n\n        self._rs_pg, rs_ranks = [],[]\n        for i in range(self._num_process_groups):\n            ranks = 
[i+k*self._num_process_groups for k in range(self._process_group_size)]\n            for j in range(self._num_groups):\n                rs_idx = [j*self._group_size+k for k in range(self._group_size)]\n                rs_rank = [ranks[k] for k in rs_idx]\n                #if self._global_rank in rs_rank:\n                #    print(\"group for reduce scatter, ranks:\", rs_rank)\n                for _ in range(self._num_rs_pg):\n                    grp = torch.distributed.new_group(ranks=rs_rank)\n                    if self._global_rank in rs_rank:\n                        self._rs_pg.append(grp)\n                if self._compute_L2_grad_norm:\n                    l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)\n                    if self._global_rank in rs_rank:\n                        self._l2_grad_norm_pg = l2_grad_norm_pg\n                        torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        for rs_pg in self._rs_pg:\n            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n            for i in range(self._num_process_groups):\n                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]\n                for j in range(self._num_groups):\n                    ag_rank = rs_ranks[j]\n                    #if self._global_rank in ag_rank:\n                    #    print(\"group for all gather, ranks:\", ag_rank)\n                    for _ in range(self._num_ag_pg):\n                        grp = torch.distributed.new_group(ranks=ag_rank)\n                        if self._global_rank in ag_rank:\n                            self._ag_pg.append(grp)\n            self._ag_st = 
[torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            for ag_pg in self._ag_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None\n        self._completion_st = torch.cuda.Stream()\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n    def _init_everything(self):\n        if not self._init_done:\n            self._first_step_init()\n            self._init_done = True\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None]*self._num_chunks\n        for 
chunk_id in range(self._num_chunks):\n            glob_chunk_id = block_id * self._num_chunks + chunk_id\n            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n            rs_stream.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(rs_stream):\n                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                with torch.cuda.stream(ar_stream):\n                    works[chunk_id].wait()\n                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n        self._reductions_works[block_id] = works\n\n        # Optionally compute L2 grad norm\n        if self._compute_L2_grad_norm and block_id == 0:\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[block_id][chunk_id].wait()\n                # Since the packed format is contiguous after reductions, only one norm is needed\n                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                # for model_parallel_rank=0, keep all gradients\n                # for the rest, subtract non_parallel gradients\n                if self._model_parallel and 
self._process_group_id: # non-zero model_parallel_rank\n                    non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')\n                    if len(self._non_parallel_grads): # non-parallel grads exist\n                        non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                                         self._overflow_buf,\n                                                                         [self._non_parallel_grads], False)[0]**2\n                    torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)\n                    l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq\n                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()\n\n    def __launch_step_kernel(self):\n        # If self._clip_grad_norm is False, we assume gradient clipping already\n        # happened outside the optimizer and self._global_scale has already\n        # been set to the combined scale, i.e. it's no longer the current loss\n        # scale used by the loss scaler.\n        # This covers model-parallel cases where the global gradient norm must be\n        # obtained via an all-reduce outside the optimizer before clipping. 
\n        combined_scale = self._global_scale\n        if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        \n        self._step += 1\n        multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,\n                self._overflow_buf,\n                self._contrib_tensor_list, # p, m, v, g, p_copy\n                self._contrib_beta1,\n                self._contrib_beta2,\n                self._contrib_bias_correction,\n                self._contrib_epsilon,\n                self._contrib_weight_decay,\n                self._param_group['lr'],\n                combined_scale,\n                self._step,\n                self.eps_mode)\n\n    def _pipeline_step(self):\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel()\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, 
self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step:\n            if self._overlap_reductions:\n                flush_block = self._get_flush_block()\n                while flush_block:\n                    block_id = flush_block[0] // self._block_size\n                    self._pipeline_block_reductions(block_id)\n                    flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def has_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) method.\n        Clears the overflow flag.\n        \"\"\"\n        has_overflow = self._has_overflow\n        self._has_overflow = False\n        return has_overflow\n\n    @property\n    def peek_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) 
method.\n        Does not clear overflow flag.\n        \"\"\"\n        return self._has_overflow\n\n    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):\n        \"\"\"Strided check for overflow.\n        You can get status by calling has_overflow.\n        \"\"\"\n        if start >= 0 and start < end:\n            out_p = output_params[start:end]\n        else:\n            out_p = output_params\n        fused_adam_cuda.strided_check_finite(self._overflow_buf,\n                out_p,\n                stride,\n                1 if clear else 0)\n        self._has_overflow = False if self._overflow_buf.item() == 0 else True\n        return self._has_overflow\n\n    @property\n    def L2_grad_norm(self):\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n            return self._L2_grad_norm\n        else:\n            return None\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n        self._init_everything()\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        if self._compute_L2_grad_norm:\n            
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        self._pipeline_step()\n\n        with torch.cuda.stream(self._completion_st):\n            # Copy self._new_params to model params\n            multi_tensor_applier(\n                    fused_adam_cuda.maybe_cast_mt,\n                    self._overflow_buf,\n                    self._packed_flat_to_model_params)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        # save step, master weights and first/second moments\n        state_dict = {}\n        state_dict['step'] = self._step\n        state_dict['fp32_p'] = self._fp32_p\n        state_dict['fp32_m'] = self._fp32_m\n        state_dict['fp32_v'] = self._fp32_v\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If a DistributedFusedAdam instance was constructed with parameters that\n        came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``optimizer.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, 
D_out).cuda().half()\n            optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # restore step, master weights and first/second moments\n        self._step = state_dict['step']\n        self._fp32_p = state_dict['fp32_p'].to(device=\"cuda\")\n        self._fp32_m = state_dict['fp32_m'].to(device=\"cuda\")\n        self._fp32_v = state_dict['fp32_v'].to(device=\"cuda\")\n        self._resume_from_checkpoint = True\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/distributed_fused_adam_v2.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass DistributedFusedAdamV2(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        num_prestats (integer, optional): number of fp64 stats that will be\n            reduced during first fp16 gradient reduction block. \n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,\n                 compute_L2_grad_norm=False, distributed_weight_update=0,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,\n                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,\n                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if use_mt:\n            raise RuntimeError('DistributedFusedAdam does not support use_mt.')\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdamV2, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n\n        assert (len(self.param_groups) == 1), \"More than one parameter group is not supported.\"\n\n        # Way to revert a step\n        # 3 -> undo kernel + double buffer (debug, print norm of difference)\n        # 2 -> double buffer fp32 parameters\n        # 1 -> undo kernel\n        self._revert_method = revert_method\n        if self._revert_method > 1:\n            print(\"revert_method -> double 
buffer fp32 parameters, will consume more memory\")\n\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._full_pipeline = full_pipeline\n        self._compute_L2_grad_norm = compute_L2_grad_norm\n        self._L2_grad_norm = None\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        p_offset = 0\n        p_i = 0\n        self._param_state = None\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            for p in group['params']:\n                torch.distributed.broadcast(p,0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                state = self.state[p]\n                if len(state) == 0:\n                    state['step'] = 0\n                if self._param_state is None:\n                    self._param_state = state\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n 
                   self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._flat_mt = flat_mt\n        self._grads = []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._shard_size = self._block_size // self._group_size\n        self._chunk_size = self._shard_size // self._num_chunks\n        print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            
self._low_param_i[block_id] = p_i\n        print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size\n        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        self._individual_flat_grads = []\n        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):\n            self._individual_flat_grads.append(self._flat_grads[grads_info[\"param_offset\"]:grads_info[\"param_offset\"]+grads_info[\"param_grads_size\"]].view_as(p))\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            def __chunkify(p):\n                # Each shard holds self._num_chunks chunks of self._chunk_size elements\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __blockify(self._flat_grads)\n            list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_chunks = 
[[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]\n            return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks\n        self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._chunk_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return 
list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        # current arrangement\n        # \n        # self._flat_grads\n        # self._flat_grads_blocks [x self._num_blocks, self._block_size]\n        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]\n        # self._flat_grads_shards [x self._group_size, self._shard_size]\n        #\n        # self._new_params\n        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]\n        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]\n        # self._new_params_mega_chunks [x self._num_chunks, self._shard_size]\n        #\n        # self._fp32_p\n        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]\n        # self._fp32_p_chunks [x self._num_chunks, self._shard_size]\n        # each chunk contains one shard\n        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g\n        #\n        # Usage:\n        # \n        # for chunk_id in range(self._num_chunks):\n        #   works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)\n        #\n        # ----------------------------------------------------------------------------------------\n        #\n        # new arrangement\n        #\n        # NB! 
New equations for self._shard_size and self._chunk_size\n        #\n        # self._flat_grads\n        # self._flat_grads_blocks [x self._num_blocks, self._block_size]\n        # self._flat_grads_shards [x self._group_size, self._shard_size]\n        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]\n        #\n        # self._new_params\n        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]\n        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]\n        # self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]\n        #\n        # self._fp32_p\n        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]\n        # self._fp32_p_chunks [x self._num_chunks, self._chunk_size]\n        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g\n        #\n        # Usage:\n        #\n        # work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)\n        # for chunk_id in range(self._num_chunks):\n        #   work.wait()\n        #   works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)\n        # or\n        # work.wait()\n        # works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)\n        #\n\n        # This paragraph does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size\n                flat_shard_end = flat_shard_start + self._shard_size\n                for p, grads_info in zip(self._model_params, self._grads_info):\n                    
flat_grad_start = grads_info[\"param_offset\"]\n                    flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                    clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)\n                    clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)\n                    if clipped_start < clipped_end:\n                        grad_offset = clipped_start - flat_grad_start\n                        grad_length = clipped_end - clipped_start\n                        shard_offset = clipped_start - flat_shard_start\n                        model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                        new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]\n                        self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )\n                        if shard_id == self._rank_in_group:\n                            # copy model parameters into master buffer\n                            master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]\n                            print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                            master_param_fragment.copy_(model_param_fragment)\n\n        p_in, p_out = zip(*self._packed_flat_to_model_params)\n        self._packed_flat_to_model_params = [p_in, p_out]\n\n        self._distributed_weight_update = distributed_weight_update # Is this still needed?\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for dev_i in 
range(self._group_size):\n                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]\n                for i in range(self._num_ar_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            for ar_pg in self._ar_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n        rs_ranks = []\n        for group_i in range(self._num_groups):\n            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])\n        self._rs_pg = []\n        for group_i in range(self._num_groups):\n            ranks = rs_ranks[group_i]\n            for i in range(self._num_rs_pg):\n                grp = torch.distributed.new_group(ranks=ranks)\n                if torch.distributed.get_rank() in ranks:\n                    self._rs_pg.append(grp)\n            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:\n                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        for rs_pg in self._rs_pg:\n            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n            for group_i in range(self._num_groups):\n                ranks = rs_ranks[group_i]\n                for i in range(self._num_ag_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        
self._ag_pg.append(grp)\n            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            for ag_pg in self._ag_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None\n        self._completion_st = torch.cuda.Stream()\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None]*self._num_chunks\n        rs_stream = self._rs_st[block_id%self._num_rs_pg]\n        
rs_stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(rs_stream):\n            rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)\n            for chunk_id in range(self._num_chunks):\n                works[chunk_id] = rs_work\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                with torch.cuda.stream(ar_stream):\n                    rs_work.wait()\n                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n        self._reductions_works[block_id] = works\n\n        # Optionally compute L2 grad norm\n        if self._compute_L2_grad_norm and block_id == 0:\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[block_id][chunk_id].wait()\n                # Since the packed format is contiguous after reductions, only one norm is needed\n                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()\n\n    def __launch_step_kernel(self, p, p_copy, m, v, g):\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] 
/ (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.reversible_adam(\n                p, p_copy, m, v, g,\n                self._param_group['lr'],\n                beta1,\n                beta2,\n                self._param_group['eps'],\n                combined_scale,\n                self._param_state['step']+1,\n                self.eps_mode,\n                bias_correction,\n                self._param_group['weight_decay'])\n\n    def _pipeline_block_step(self, block_id):\n        # Call step kernel once per block\n        ag_stream = self._ag_st[block_id%self._num_ag_pg]\n        with torch.cuda.stream(ag_stream):\n            for chunk_id in range(self._num_chunks):\n                self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel(\n                self._fp32_p_blocks[block_id],\n                self._fp16_p_blocks[block_id],\n                self._fp32_m_blocks[block_id],\n                self._fp32_v_blocks[block_id],\n                self._fp16_g_blocks[block_id])\n        # Call all-gather once per step.\n        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end\n        if block_id == 0:\n            for other_ag_stream in self._ag_st:\n                self._completion_st.wait_stream(other_ag_stream)\n            with torch.cuda.stream(self._completion_st):\n                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _pipeline_step(self):\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in 
range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel(\n                self._fp32_p,\n                self._fp16_p,\n                self._fp32_m,\n                self._fp32_v,\n                self._fp16_g)\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step:\n            if self._overlap_reductions:\n                flush_block = self._get_flush_block()\n                while flush_block:\n                    block_id = flush_block[0] // self._block_size\n                    self._pipeline_block_reductions(block_id)\n                    if self._full_pipeline:\n                        self._pipeline_block_step(block_id)\n                    flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def has_overflow(self):\n        \"\"\"Check if overflows were detected by any call to 
step(...) method.\n        Clears the overflow flag.\n        \"\"\"\n        has_overflow = self._has_overflow\n        self._has_overflow = False\n        return has_overflow\n\n    @property\n    def peek_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) method.\n        Does not clear overflow flag.\n        \"\"\"\n        return self._has_overflow\n\n    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):\n        \"\"\"Strided check for overflow.\n        You can get status by calling has_overflow.\n        \"\"\"\n        if start >= 0 and start < end:\n            out_p = output_params[start:end]\n        else:\n            out_p = output_params\n        fused_adam_cuda.strided_check_finite(self._overflow_buf,\n                out_p,\n                stride,\n                1 if clear else 0)\n        self._has_overflow = False if self._overflow_buf.item() == 0 else True\n        return self._has_overflow\n\n    @property\n    def L2_grad_norm(self):\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n            return self._L2_grad_norm\n        else:\n            return None\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not 
self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def revert_step(self):\n        \"\"\"Revert effect of previously calling partial_step.\n        \"\"\"\n        # Call undo kernel once per step\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.maybe_adam_undo(\n                    torch.empty([0]),\n                    self._fp32_p,\n                    self._fp32_m,\n                    self._fp32_v,\n                    self._fp16_g,\n                    self._param_group['lr'],\n                    beta1,\n                    beta2,\n                    self._param_group['eps'],\n                    combined_scale,\n                    self._param_state['step']+1,\n                    self.eps_mode,\n                    bias_correction,\n                    self._param_group['weight_decay'])\n\n    def step(self, closure=None, skip_overflow_check=False):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if self._last_step or not self._overlap_reductions or not self._full_pipeline:\n            self._pipeline_step()\n\n        with torch.cuda.stream(self._completion_st):\n            # Check for 
overflow\n            # Store state for loss scaler calculation\n            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)\n            if has_overflow:\n                self.revert_step()\n            else:\n                # Copy self._new_params to model params\n                for p in self._model_params: self.state[p]['step'] += 1\n                multi_tensor_applier(\n                        fused_adam_cuda.maybe_cast_mt,\n                        self._overflow_buf,\n                        self._packed_flat_to_model_params)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/distributed_fused_adam_v3.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass DistributedFusedAdamV3(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        num_prestats (integer, optional): number of fp64 stats that will be\n            reduced during first fp16 gradient reduction block. \n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,\n                 compute_L2_grad_norm=False, distributed_weight_update=0,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,\n                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,\n                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if use_mt:\n            raise RuntimeError('DistributedFusedAdam does not support use_mt.')\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdamV3, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n\n        assert (len(self.param_groups) == 1), \"More than one parameter group is not supported.\"\n\n        # Way to revert a step\n        # 3 -> undo kernel + double buffer (debug, print norm of difference)\n        # 2 -> double buffer fp32 parameters\n        # 1 -> undo kernel\n        self._revert_method = revert_method\n        if self._revert_method > 1:\n            print(\"revert_method -> double buffer fp32 parameters, will consume 
more memory\")\n\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._full_pipeline = full_pipeline\n        self._L2_grad_norm = None\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        p_offset = 0\n        p_i = 0\n        self._param_state = None\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            for p in group['params']:\n                torch.distributed.broadcast(p,0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                state = self.state[p]\n                if len(state) == 0:\n                    state['step'] = 0\n                if self._param_state is None:\n                    self._param_state = state\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, 
\"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._flat_mt = flat_mt\n        self._grads = []\n        self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._shard_size = self._total_param_size // self._group_size\n        print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._flat_params = torch.zeros_like(self._flat_grads)\n\n        def _flat_split(flat):\n            def 
__flat_blockify(flat):\n                return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __flat_shardify(flat):\n                return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            return __flat_blockify(flat), __flat_shardify(flat)\n        self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)\n        self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)\n\n        # master params\n        self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n\n        # copy model params to flat_params and set_ model params to flat_params.\n        self._individual_flat_grads = []\n        with torch.no_grad():\n            for p, grads_info in zip(self._model_params, self._grads_info):\n                start = grads_info[\"param_offset\"]\n                end = start + grads_info[\"param_grads_size\"]\n                flat_p = self._flat_params[start:end].view_as(p)\n                flat_p.copy_(p)\n                p.set_(flat_p)\n                flat_grad = self._flat_grads[start:end]\n                self._individual_flat_grads.append(flat_grad)\n        self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())\n\n        self._dwu_st = torch.cuda.Stream()\n        self._l2_grad_norm_st = torch.cuda.Stream()\n        for group_i in range(self._num_groups):\n            ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]\n            pg = torch.distributed.new_group(ranks=ranks)\n            if torch.distributed.get_rank() in ranks:\n                self._ag_pg = pg\n                
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n    @property\n    def has_overflow(self):\n        return True if not self.L2_grad_norm is None and not math.isfinite(self.L2_grad_norm) else False\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def __launch_step_kernel(self, p, p_copy, m, v, g):\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.reversible_adam(\n                p, p_copy, m, v, g,\n                self._param_group['lr'],\n                beta1,\n                beta2,\n                self._param_group['eps'],\n                
combined_scale,\n                self._param_state['step']+1,\n                self.eps_mode,\n                bias_correction,\n                self._param_group['weight_decay'])\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step and self._overlap_reductions:\n            flush_block = self._get_flush_block()\n            while flush_block:\n                block_id = flush_block[0] // self._block_size\n                self._dwu_st.wait_stream(torch.cuda.current_stream())\n                with torch.cuda.stream(self._dwu_st):\n                    self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n                    torch.distributed.all_reduce(self._flat_grads_blocks[block_id])\n                if block_id == 0:\n                    self._l2_grad_norm_st.wait_stream(self._dwu_st)\n                    with torch.cuda.stream(self._l2_grad_norm_st):\n                        self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()\n                flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n    
    return self._global_scale\n\n    @property\n    def L2_grad_norm(self):\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n        return self._L2_grad_norm\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, flat_grad in enumerate(self._individual_flat_grads):\n                if not self._grads_generated[param_i]:\n                    flat_grad.zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            self._dwu_st.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(self._dwu_st):\n                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n                torch.distributed.all_reduce(self._flat_grads)\n            self._l2_grad_norm_st.wait_stream(self._dwu_st)\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None, skip_overflow_check=False):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        with torch.cuda.stream(self._dwu_st):\n            self.__launch_step_kernel(\n                self._fp32_p,\n                self._flat_params_shards[self._rank_in_group],\n                self._fp32_m,\n                self._fp32_v,\n                self._flat_grads_shards[self._rank_in_group])\n            torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, 
no_copy=True)\n            for p in self._model_params: self.state[p]['step'] += 1\n\n        torch.cuda.current_stream().wait_stream(self._dwu_st)\n\n        return loss\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/distributed_fused_lamb.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nimport torch.distributed.distributed_c10d as c10d\n\nclass DistributedFusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n    \n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n    \n    This version of fused LAMB implements 2 fusions.\n      \n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n    \n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n        \n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n    \n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n        \n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n    \n    In general, ``opt_level=\"O1\"`` is recommended.\n    \n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n    \n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay(also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether set grad to None when zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0\n            weight decay parameter (default: False)\n        step_supports_amp_scaling(boolean, optional): whether to use customized\n            gradient unscaling logic (default: True)\n    \n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    class AtomicCounter(object):\n        def __init__(self):\n            self.value = 0\n            self.order = []\n            import threading\n            self._lock = threading.Lock()\n\n        def add(self, idx):\n            with self._lock:\n                self.value += 1\n                self.order.append(idx)\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True, grad_averaging=True,\n                 betas=(0.9, 0.999), eps=1e-8, \n                 weight_decay=0., max_grad_norm=0., \n                 adam_w_mode=True, use_nvlamb=False,\n                 step_supports_amp_scaling=True, overlap_reductions=True,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,\n                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, \n                 e5m2_allgather=False, verbose=False, clip_after_ar=True):\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n\n        super(DistributedFusedLAMB, self).__init__(params, defaults)\n\n        global fused_adam_cuda, distributed_lamb_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n        distributed_lamb_cuda = importlib.import_module(\"distributed_lamb_cuda\")\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term\n        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights\n        import amp_C\n        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n\n        self._grad_averaging = grad_averaging\n        self._adam_w_mode = 1 
if adam_w_mode else 0\n        self._use_nvlamb = use_nvlamb\n        self._step_supports_amp_scaling = step_supports_amp_scaling\n        self._is_accumulation_step = False\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._e5m2_allgather = e5m2_allgather\n        self._verbose = verbose\n        self._clip_after_ar = clip_after_ar\n        self._L2_grad_norm = None\n        \n        self._current_process_group = c10d._get_default_group()\n        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')\n\n        self._resume_from_checkpoint = False\n        self._step = torch.cuda.IntTensor([0])\n\n        # Master weight, moment, gradient buffers\n        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for dev_i in range(self._group_size):\n                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]\n                for i in range(self._num_ar_pg):\n                    if self._verbose:\n                        print(f\"creating new group 
{i}: {ranks}\")\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:\n                        if self._verbose:\n                            print(f\"group {i}: init barrier (device: {torch.cuda.current_device()})\")\n                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])\n                    if self._verbose:\n                        print(f\"created new group {i}\")\n\n                    if torch.distributed.get_rank() in ranks:\n                        self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            #for ar_pg in self._ar_pg:\n            #    torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n        rs_ranks = []\n        for group_i in range(self._num_groups):\n            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])\n        self._rs_pg = []\n        for group_i in range(self._num_groups):\n            ranks = rs_ranks[group_i]\n            for i in range(self._num_rs_pg):\n                grp = torch.distributed.new_group(ranks=ranks)\n                if torch.distributed.get_rank() in ranks:\n                    self._rs_pg.append(grp)\n            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n            if torch.distributed.get_rank() in ranks:\n                self._l2_grad_norm_pg = l2_grad_norm_pg\n                #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        #for rs_pg in self._rs_pg:\n        #    torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n         
   for group_i in range(self._num_groups):\n                ranks = rs_ranks[group_i]\n                for i in range(self._num_ag_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        self._ag_pg.append(grp)\n            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            #for ag_pg in self._ag_pg:\n            #    torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream()\n        self._completion_st = torch.cuda.Stream()\n        self._step.record_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        self._one = torch.cuda.IntTensor([1])\n\n        self._first_step = True\n        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False\n        self._param_order = self.AtomicCounter()\n\n    def _lazy_init_stage1(self):\n        if self._lazy_init_stage1_done: return\n\n        p_offset = 0\n        p_i = 0\n        self._model_params = []\n        self._grad_accs = []\n        self._group_properties = []\n        for group in self.param_groups:\n            prev = None\n            beta1, beta2 = group['betas']\n            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0\n            bias_correction = 1 if group['bias_correction'] else 0\n            eps = group['eps']\n            weight_decay = group['weight_decay']\n            for p in group['params']:\n                torch.distributed.broadcast(p, 0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                self._group_properties.append((\n                    weight_decay,\n                    bias_correction,\n                    beta1,\n                    beta2,\n                    beta3,\n                    eps\n    
                ))\n                p_grads_size = p.numel()\n                def wrapper(param, param_i):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        if self._first_step:\n                            # first time\n                            self._param_order.add(param_i)\n                        else:\n                            idx = self._param_order.order.index(param_i)\n                            self._do_overlapped_reduction(idx, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                wrapper(p, p_i)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._model_params)\n        self._grads_fp16, self._grads_fp32 = [], []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._chunk_size = self._block_size // self._num_chunks\n        self._shard_size = self._chunk_size // self._group_size\n        
#print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size\n        # initialize master weights, moments buffers if not loaded from checkpoint\n        if self._fp32_p is None:\n            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            list_of_blocks = 
__blockify(self._flat_grads)\n            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]\n            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards\n        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._shard_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size\n                return 
[p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        self._lazy_init_stage1_done = True\n\n    def _lazy_init_stage2(self):\n        if self._lazy_init_stage2_done: return\n\n        self._param_order.order.reverse()\n\n        # re-order model_params, grad_accs, group_properties lists\n        self._model_params = [self._model_params[i] for i in self._param_order.order]\n        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]\n        self._group_properties = [self._group_properties[i] for i in self._param_order.order]\n\n        # re-collect grads info (size, offset) after ordering\n        prev = None\n        p_offset = 0\n        self._grads_info = []\n        self._individual_flat_grads = []\n        for i, p in enumerate(self._model_params):\n            p_grads_size = p.numel()\n            self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))\n            # for the first iteration\n            self._do_overlapped_reduction(i, p)\n            p_offset += p_grads_size\n            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n            # 
RNN is one example of consecutive parameters:\n            # (weight_ih, weight_hh, bias_ih, bias_hh)\n            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                p_offset = ((p_offset + 63) // 64) * 64\n            prev = p\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        #print(\"self._low_param_i\", self._low_param_i)\n\n        # This paragraph does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params_fp16 = []\n        self._packed_flat_to_model_params_fp32 = []\n        self._model_params_num = len(self._model_params)\n        self._contrib_tensor_list = []\n        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1\n        self._contrib_update_frag_for_norm = []\n        self._contrib_model_param_for_norm_fp16 = []\n        self._contrib_model_param_for_norm_fp32 = []\n        self._contrib_model_param_for_norm_is_fp16 = []\n        self._model_param_is_contrib = []\n        self._contrib_group_properties = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size\n                    flat_shard_end = flat_shard_start + self._shard_size\n                    for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):\n                        flat_grad_start = 
grads_info[\"param_offset\"]\n                        flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                        clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)\n                        clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)\n                        if clipped_start < clipped_end:\n                            grad_offset = clipped_start - flat_grad_start\n                            grad_length = clipped_end - clipped_start\n                            shard_offset = clipped_start - flat_shard_start\n                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                            if model_param_fragment.dtype == torch.float16:\n                                self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )\n                            else:\n                                self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )\n                            if shard_id == self._rank_in_group:\n                                self._model_param_is_contrib.append(param_i)\n                                # copy model parameters into master buffer\n                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_u_fragment = 
self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                #print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                                if not self._resume_from_checkpoint:\n                                    master_param_fragment.copy_(model_param_fragment)\n                                self._contrib_group_properties.append(group_props)\n                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy\n                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)\n                                if p.dtype == torch.float16:\n                                    self._contrib_model_param_for_norm_fp16.append(p)\n                                else:\n                                    self._contrib_model_param_for_norm_fp32.append(p)\n                                self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)\n                                if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i\n                                self._contrib_max_param_i = param_i\n        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)\n        if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None\n        if 
len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None\n        self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')\n        self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')\n        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')\n\n        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))\n        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]\n        self._contrib_update_weights_tensor_list = [u, p, p_copy]\n\n        math_type = self._fp32_u.dtype\n        decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))\n        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')\n        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')\n        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')\n        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')\n        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')\n        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')\n\n        self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None\n        self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None\n\n        self._lazy_init_stage2_done = True\n\n        self.complete_reductions()\n        self._first_step = False\n\n    def set_is_accumulation_step(self, is_accumulation_step):\n        self._is_accumulation_step = 
is_accumulation_step\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        if self._clip_after_ar:\n            self._flatten_grad_mt(1.0/self._world_size)\n\n            # Reduction within each node\n            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n            # The output format is the same as the fp32 master parameters\n            works = [None]*self._num_chunks\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n                rs_stream.wait_stream(torch.cuda.current_stream())\n                with torch.cuda.stream(rs_stream):\n                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n            # Reduction across nodes for each rank\n            if self._num_groups > 1:\n                for chunk_id in range(self._num_chunks):\n                
    glob_chunk_id = block_id * self._num_chunks + chunk_id\n                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                    with torch.cuda.stream(ar_stream):\n                        works[chunk_id].wait()\n                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n            self._reductions_works[block_id] = works\n\n            # Compute L2 grad norm\n            if block_id == 0:\n                with torch.cuda.stream(self._l2_grad_norm_st):\n                    for block_id in range(self._num_blocks):\n                        for chunk_id in range(self._num_chunks):\n                            self._reductions_works[block_id][chunk_id].wait()\n                    # Since the packed format is contiguous after reductions, only one norm is needed\n                    l2_grad_norm_sq = torch.empty([1], device='cuda')\n                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()\n        else:\n            # Copy model grads to flat grads buffer\n            self._flatten_grad_mt(1.0)\n\n            # Compute L2 grad norm\n            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n            # Apply clipping & pre-reduction scaling on grads\n            loss_scale = self.global_scale\n            max_grad_norm = loss_scale*self.defaults['max_grad_norm']\n            coeff = max_grad_norm /(1e-6+self.L2_grad_norm)\n            coeff = (coeff>1) * self._one + (coeff<=1) * coeff\n            
tmp = torch.cat(((self._one), (coeff)))\n            index = (coeff+1>coeff).int()\n            scale = tmp.index_select(0, index).half()/self._world_size\n            self._flat_grads.mul_(scale)\n\n            # Reduction within each node\n            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n            # The output format is the same as the fp32 master parameters\n            works = [None]*self._num_chunks\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n                rs_stream.wait_stream(torch.cuda.current_stream())\n                rs_stream.wait_stream(self._l2_grad_norm_st)\n                with torch.cuda.stream(rs_stream):\n                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n            # Reduction across nodes for each rank\n            if self._num_groups > 1:\n                for chunk_id in range(self._num_chunks):\n                    glob_chunk_id = block_id * self._num_chunks + chunk_id\n                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                    with torch.cuda.stream(ar_stream):\n                        works[chunk_id].wait()\n                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n            self._reductions_works[block_id] = works\n\n            if block_id == 0:\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[block_id][chunk_id].wait()\n\n    def __compute_contrib_param_norm(self):\n        if 
self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:\n            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]\n            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]\n            # gnorm gathers per-parameter float norms, so it must be a float tensor, not bool\n            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda')\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)\n        elif self._contrib_model_param_for_norm_fp16 is not None:\n            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]\n        elif self._contrib_model_param_for_norm_fp32 is not None:\n            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]\n        return gnorm\n\n    def __compute_contrib_update_norm(self):\n        l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')\n        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2\n        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)\n        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])\n        l2_norm = torch.sqrt(l2_norm)\n        return l2_norm\n\n    def _pipeline_step(self):\n        global_scale = self.global_scale\n        # if clip before ar, set max_grad_norm to 0\n        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n        global_grad_norm = self.L2_grad_norm\n\n        # check global_grad_norm and 
fill overflow_buf\n        is_finite = (global_grad_norm + 1 > global_grad_norm).int()\n        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1\n        torch.distributed.all_reduce(is_finite,\n                                     op=torch.distributed.ReduceOp.MIN,\n                                     group=self._current_process_group)\n        torch.distributed.all_reduce(self._overflow_buf,\n                                     op=torch.distributed.ReduceOp.MAX,\n                                     group=self._current_process_group)\n\n        # increment step counter if no overflow\n        self._step += is_finite\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            param_norm = self.__compute_contrib_param_norm()\n            multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,\n                    self._overflow_buf,\n                    self._contrib_compute_update_term_tensor_list, # g, p, m, v, u\n                    self._contrib_beta1,\n                    self._contrib_beta2,\n                    self._contrib_beta3,\n                    self._contrib_bias_correction,\n                    self._step,\n                    self._contrib_epsilon,\n                    self._adam_w_mode,\n                    self._contrib_weight_decay,\n                    global_scale,\n                    global_grad_norm,\n                    max_grad_norm)\n            upd_norm = self.__compute_contrib_update_norm()\n            multi_tensor_applier(self.multi_tensor_lamb_update_weights,\n                    
self._overflow_buf,\n                    self._contrib_update_weights_tensor_list, # u, p, p_copy\n                    param_norm,\n                    upd_norm,\n                    self._offsets,\n                    self._lr,\n                    self._contrib_weight_decay,\n                    global_grad_norm,\n                    self._use_nvlamb)\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if len(self._grads_fp16) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp16)),\n                    scale)\n            self._grads_fp16 = []\n        if len(self._grads_fp32) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp32)),\n                    scale)\n            self._grads_fp32 = []\n\n    def _do_overlapped_reduction(self, param_i, param):\n        if not self._is_accumulation_step:\n            # handle overlapped reductions\n            if param.dtype == torch.float16:\n                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )\n            else:\n                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )\n            self._grads_generated[param_i]=True\n            if not self._first_step and not self._last_step:\n                if self._overlap_reductions:\n                    flush_block = self._get_flush_block()\n                    while flush_block:\n                        block_id = flush_block[0] // self._block_size\n                        self._pipeline_block_reductions(block_id)\n                        flush_block = 
self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def L2_grad_norm(self):\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n        return self._L2_grad_norm\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._first_step or self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None, grad_scaler=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        self._pipeline_step()\n\n        if grad_scaler is not None:\n            found_inf = self._overflow_buf.float()\n            optimizer_state = grad_scaler._per_optimizer_states[id(self)]\n            current_device = torch.device('cuda', torch.cuda.current_device())\n            
optimizer_state[\"found_inf_per_device\"][current_device] = found_inf\n\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n\n        with torch.cuda.stream(self._completion_st):\n            # Copy self._new_params to model params\n            with torch.no_grad():\n                if self._packed_flat_to_model_params_fp16 is not None:\n                    multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n                            self._overflow_buf,\n                            self._packed_flat_to_model_params_fp16)\n                if self._packed_flat_to_model_params_fp32 is not None:\n                    multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n                            self._overflow_buf,\n                            self._packed_flat_to_model_params_fp32)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        # save step, master weights and first/second moments\n        state_dict = {}\n        state_dict['step'] = self._step\n        state_dict['fp32_p'] = self._fp32_p\n        state_dict['fp32_m'] = self._fp32_m\n        state_dict['fp32_v'] = self._fp32_v\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If a DistributedFusedAdam instance was constructed from some ``init_optimizer``,\n        whose 
parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``optimizer.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # restore step, master weights and first/second moments\n        self._step = state_dict['step']\n        self._fp32_p = state_dict['fp32_p'].to(device=\"cuda\")\n        self._fp32_m = state_dict['fp32_m'].to(device=\"cuda\")\n        self._fp32_v = state_dict['fp32_v'].to(device=\"cuda\")\n        self._resume_from_checkpoint = True\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/fp16_optimizer.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FP16_Optimizer(object):\n    \"\"\"\n    :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.\n    Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.\n    Refer to the apex.fp16_utils documentation for more information.\n    Example::\n        model = torch.nn.Linear(D_in, D_out).cuda().half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())\n        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n        ...\n        # loss.backward() becomes:\n        optimizer.backward(loss)\n        ...\n    Example with dynamic loss scaling::\n        ...\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n                                   # optional arg to control dynamic loss scaling behavior\n                                   # dynamic_loss_args={'scale_window' : 500})\n                                   # Usually, dynamic_loss_args is not necessary.\n    \"\"\"\n\n    def __init__(self,\n                 init_optimizer,\n                 static_loss_scale=1.0,\n                 dynamic_loss_scale=False,\n                 dynamic_loss_args=None,\n                 verbose=True):\n\n        print(\"\\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*\")\n        print(\"To update, use updated optimizers with AMP.\")\n        # The fused optimizer does all the work. We need this layer for two reasons:\n        # 1. maintain same user API from apex.fp16_utils\n        # 2. keep common stuff here in case we need to add new fused optimizer later\n\n        if not torch.cuda.is_available():\n            raise SystemError(\"Cannot use fp16 without CUDA.\")\n        self.optimizer = init_optimizer\n\n        self.fp16_groups = [] # model params\n        self.fp32_groups = [] # master weights\n\n        # iterate over param_groups\n        for param_group in self.optimizer.param_groups:\n            fp16_group = []\n            fp32_group = []\n            for p in param_group['params']:\n                fp16_group.append(p)\n                fp32_group.append(p.clone().float().detach())\n            self.fp16_groups.append(fp16_group)\n            self.fp32_groups.append(fp32_group)\n            param_group['params'] = fp32_group\n\n        if multi_tensor_applier.available:\n            import amp_C\n            self.overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n        else:\n            raise RuntimeError('FP16_Optimizer requires cuda extensions')\n\n        # we may have a way of fusing dynamic scale. 
Do not support for now\n        if dynamic_loss_scale:\n            if dynamic_loss_args is not None:\n                raise SystemError(\"Do not support dynamic loss scale args for now.\")\n            self.dynamic_loss_scale = True\n            self.cur_scale = 2**16\n            self.cur_iter = 0\n            self.last_overflow_iter = -1\n            self.scale_factor = 2\n            self.scale_window = 1000\n        else:\n            self.dynamic_loss_scale = False\n            self.cur_iter = 0\n            self.cur_scale = static_loss_scale\n        self.verbose = verbose\n\n    def zero_grad(self, set_grads_to_None=True):\n        \"\"\"\n        Zero FP16 parameter grads.\n        \"\"\"\n        # FP32 grad should never exist.\n        # For speed, set model fp16 grad to None by default\n        for group in self.fp16_groups:\n            for p in group:\n                if set_grads_to_None:\n                    p.grad = None\n                else:\n                    if p.grad is not None:\n                        p.grad.detach_()\n                        p.grad.zero_()\n\n    def step(self, closure=None):\n        \"\"\"\n        Not supporting closure.\n        \"\"\"\n        fp16_grads = []\n        norm_groups = []\n        skip = False\n\n        for group in self.fp16_groups:\n            fp16_grad = []\n            for i, p in enumerate(group):\n                fp16_grad.append(p.grad)\n            fp16_grads.append(fp16_grad)\n        \n        # nan check\n        self.overflow_buf.zero_()\n        for fp16_grad in fp16_grads:\n            if len(fp16_grad) > 0:\n                norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                             self.overflow_buf,\n                                                             [fp16_grad], True)\n                norm_groups.append(norm)\n                if self.overflow_buf.item() != 0:\n                    skip = True\n\n        
if skip:\n            self._update_scale(skip)\n            return\n\n        # norm is in fact norm*cur_scale\n        self.optimizer.step(grads=fp16_grads,\n                            output_params=self.fp16_groups,\n                            scale=self.cur_scale,\n                            grad_norms=norm_groups)\n\n        self._update_scale(False)\n        return\n\n    def backward(self, loss):\n        \"\"\"\n        :attr:`backward` performs the following steps:\n        1. fp32_loss = loss.float()\n        2. scaled_loss = fp32_loss*loss_scale\n        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves\n        \"\"\"\n        scaled_loss = (loss.float()) * self.cur_scale\n        scaled_loss.backward()\n\n    def _update_scale(self, skip):\n        if self.dynamic_loss_scale:\n            if skip:\n                if self.verbose:\n                    print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                    print(\"Using dynamic loss scale of\", self.cur_scale)\n                self.cur_scale = max(self.cur_scale/self.scale_factor, 1)\n                self.last_overflow_iter = self.cur_iter\n            else:\n                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:\n                    self.cur_scale *= self.scale_factor\n        else:\n            if skip:\n                print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                print(\"Using static loss scale of\", self.cur_scale)\n        self.cur_iter +=1\n        return\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for 
example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return self.optimizer.param_groups\n\n    def _set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale\n        state_dict['cur_scale'] = self.cur_scale\n        state_dict['cur_iter'] = self.cur_iter\n        if state_dict['dynamic_loss_scale']:\n            state_dict['last_overflow_iter'] = self.last_overflow_iter\n            state_dict['scale_factor'] = self.scale_factor\n            state_dict['scale_window'] = self.scale_window\n        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()\n        state_dict['fp32_groups'] = self.fp32_groups\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, 
static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # I think it should actually be ok to reload the optimizer before the model.\n        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']\n        self.cur_scale = state_dict['cur_scale']\n        self.cur_iter = state_dict['cur_iter']\n        if state_dict['dynamic_loss_scale']:\n            self.last_overflow_iter = state_dict['last_overflow_iter']\n            self.scale_factor = state_dict['scale_factor']\n            self.scale_window = state_dict['scale_window']\n        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.\n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  There are two options.\n        # 1:  Refresh the master params from the model's fp16 params.\n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 2.\n        #\n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device\n        # of their associated parameters, because it's possible those buffers might not exist yet in\n        # the current optimizer instance.  
In our case, as long as the current FP16_Optimizer has been\n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n        for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):\n            for _current, _saved in zip(current, saved):\n                _current.data.copy_(_saved.data)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/fused_adam.py",
    "content": "import types\nimport torch\nimport importlib\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n\n    .. _Adam - A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._use_multi_tensor = False\n        if use_mt:\n            if not multi_tensor_applier.available:\n                print(\"Warning:  multi_tensor_applier is unavailable\")\n            else:\n                self._use_multi_tensor = True\n                self._overflow_buf = torch.cuda.IntTensor([0])\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if amsgrad:\n            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(FusedAdam, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. (default: None)\n            output params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. 
Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. (default: 1)\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if hasattr(self, \"_amp_stash\"):\n            grads = self._amp_stash.grads\n            output_params = self._amp_stash.output_params\n            scale = self._amp_stash.scale*self._amp_scale_adjustment\n            grad_norms = self._amp_stash.grad_norms\n\n        if grads is None:\n            grads_group = [None]*len(self.param_groups)\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0])!=list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            output_params_group = [None]*len(self.param_groups)\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0])!=list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        if grad_norms is None:\n            grad_norms = [None]*len(self.param_groups)\n\n        for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):\n            if grads_this_group is None:\n               grads_this_group = [None]*len(group['params'])\n            if output_params_this_group is None:\n               output_params_this_group = [None]*len(group['params'])\n\n            # compute combined scale factor for this group\n            combined_scale = scale\n            if group['max_grad_norm'] > 0:\n                # norm is in fact norm*scale\n            
    clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']\n                if clip > 1:\n                    combined_scale = clip * scale\n\n            bias_correction = 1 if group['bias_correction'] else 0\n\n            if self._use_multi_tensor:\n                if output_params:\n                    tensorlists = [[],[],[],[],[]]\n                else:\n                    tensorlists = [[],[],[],[]]\n                tensordevice = None\n\n            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):\n                #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients\n                if p.grad is None and grad is None:\n                    continue\n                if grad is None:\n                    grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param\n                if self._use_multi_tensor:\n                    pl = [p.data, exp_avg, exp_avg_sq, grad]\n                    if output_param is not None:\n                        pl.append(out_p)\n\n                    for tl, t in zip(tensorlists, pl):\n             
           tl.append(t)\n\n                    if tensordevice is None:\n                        tensordevice = p.device\n                    elif tensordevice != p.device:\n                        raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device')\n\n                else:\n                    with torch.cuda.device(p.device):\n                        fused_adam_cuda.adam(p.data,\n                                             out_p,\n                                             exp_avg,\n                                             exp_avg_sq,\n                                             grad,\n                                             group['lr'],\n                                             beta1,\n                                             beta2,\n                                             group['eps'],\n                                             combined_scale,\n                                             state['step'],\n                                             self.eps_mode,\n                                             bias_correction,\n                                             group['weight_decay'])\n\n            if self._use_multi_tensor:\n                with torch.cuda.device(tensordevice):\n                    multi_tensor_applier(\n                        fused_adam_cuda.adam_mt,\n                        self._overflow_buf,\n                        tensorlists,\n                        group['lr'],\n                        beta1,\n                        beta2,\n                        group['eps'],\n                        combined_scale,\n                        state['step'],\n                        self.eps_mode,\n                        bias_correction,\n                        group['weight_decay'])\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/fused_lamb.py",
    "content": "import torch\nimport importlib\nimport math\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--deprecated_fused_lamb\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n\n        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay(also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether set grad to None when zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,\n                 amsgrad=False, adam_w_mode=True,\n                 grad_averaging=True, set_grad_none=True,\n                 max_grad_norm=1.0):\n        if amsgrad:\n            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            
fused_lamb_cuda = importlib.import_module(\"fused_lamb_cuda\")\n            self.multi_tensor_lamb = fused_lamb_cuda.lamb\n        else:\n            raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dytpe == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')\n\n        g_norm_32, g_norm_16 = 0.0, 0.0\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_32], False)[0].item()\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                     
        [g_all_16], False)[0].item()\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)\n        max_grad_norm = self.defaults['max_grad_norm']\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n 
                   v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm)\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm)\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/optimizers/fused_sgd.py",
    "content": "import types\nimport torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    This version of fused SGD implements 2 fusions.\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.\n   \n    :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad. \n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        model = ...\n        model.half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())\n        # wrap with FP16_Optimizer\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n        optimizer.zero_grad()\n\t...\n        optimizer.backward(loss)\n        optmizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et. al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. 
math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et. al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(self, params, lr=required, momentum=0, dampening=0,\n                 weight_decay=0, nesterov=False,\n                 wd_after_momentum=False,\n                 materialize_master_grads=True):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,\n                        weight_decay=weight_decay, nesterov=nesterov)\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = wd_after_momentum\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions')\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('nesterov', False)\n\n    def get_momentums(self, params):\n  
      momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if 'momentum_buffer' not in param_state:\n                first_run = True\n                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state['momentum_buffer'])\n        return momentums, first_run\n    \n    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. (default: None)\n            output_params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. 
(default: 1)\n        \"\"\"\n        if hasattr(self, \"_amp_stash\"):\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.')\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if grads is None:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \\\n\t                       with apex.contrib.optimizers.FP16_Optimizer \\\n\t\t\t       which provides grads.')\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0])!=list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \\\n                               with apex.contrib.optimizers.FP16_Optimizer \\\n                               which provides output_params.')\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0])!=list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        for group, grads_this_group, output_params_this_group in zip(self.param_groups, \n\t                                                             grads_group, \n                                                                     output_params_group):\n            if grads_this_group is None or output_params_this_group is None: \n                raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \\\n                                    when all parameters require grad.')\n            \n            weight_decay = group['weight_decay']\n            momentum = group['momentum']\n            dampening = group['dampening']\n            nesterov = 
group['nesterov']\n            lr = group['lr']\n\n            first_runs = [True, True]\n            \n            # output_params_this_group: original weights (either fp16 or fp32)\n            # group['params']: master weights (fp32)\n\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # fp32, fp32, fp32, No\n            fp32_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float32]\n            fp32_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float32]\n            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n            fp32_set = [fp32_grads, fp32_params, fp32_momentums]\n\n            # fp16, fp32, fp32, Yes\n            fp16_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float16]\n            fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]\n            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n            fp16_params = [p1 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]\n            fp16_set = [fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_params]\n\n            launch_sets = [fp16_set, fp32_set]\n\n            for launch_set, first_run in zip(launch_sets, first_runs):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                        dampening,\n                        lr,\n                    
    nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0/scale)\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/README.md",
    "content": "# Introduction to ASP\n\nThis serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.\n\n## Importing ASP\n```\nfrom apex.contrib.sparsity import ASP\n```\n\n## Initializing ASP\n\nApart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:\n```\nASP.prune_trained_model(model, optimizer)\n```\n\nIn the context of a typical PyTorch training loop, it might look like this:\n```\nASP.prune_trained_model(model, optimizer)\n\nx, y = DataLoader(args)\nfor epoch in range(epochs):\n    y_pred = model(x)\n    loss = loss_function(y_pred, y)\n    loss.backward()\n    optimizer.step()\n\ntorch.save(...)\n```\nThe `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. \n\n## Generate a Sparse Network\n\nThe following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode.\n\n```\n(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.\n(2) Fine-tune  the  pruned  model  with  optimization  method  and  hyper-parameters (learning-rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model.\n(3) (If required) Quantize the model.\n```\n\nIn code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).\n\n```\n\nmodel = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)\ncriterion = ... 
# compare ground truth with model predition; use the same criterion as used to generate the dense trained model\noptimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model\nlr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model\n\nfrom apex.contrib.sparsity import ASP     \nASP.prune_trained_model(model, optimizer) #pruned a trained model\n\nx, y = DataLoader(args)\nfor epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model\n    y_pred = model(x)\n    loss = criterion(y_pred, y)\n    lr_scheduler.step()\n    loss.backward()\n    optimizer.step()\n\ntorch.save(...) # saves the pruned checkpoint with sparsity masks \n```\n\n## Non-Standard Usage\n\nIf your goal is to easily perpare a network for accelerated inference, please follow the recipe above.  However, ASP can also be used to perform experiments in advanced techniques like training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:\n\n```\nASP.compute_sparse_masks()\n```\n\nA more thorough example can be found in `./test/toy_problem.py`. \n\n\n\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/__init__.py",
    "content": "from .sparse_masklib import create_mask\nfrom .asp import ASP\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/asp.py",
    "content": "import types\nimport torch\nfrom .sparse_masklib import create_mask\n\ntorchvision_imported=True\ntry:\n    import torchvision\nexcept ImportError:\n    print(\"[ASP][Warning] torchvision cannot be imported.\")\n    torchvision_imported=False\n\ndef eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):\n    eligible_modules_list = []\n    for name, mod in model.named_modules():\n        if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:\n            if allowed_layer_names is not None and name not in allowed_layer_names:\n                continue\n            eligible_modules_list.append((name, mod))\n    return eligible_modules_list\n\nclass ASP:\n    __model = None\n    __verbosity = 0\n    __optimizer = None\n    __sparse_parameters = []\n    __calculate_mask = None\n\n    @classmethod\n    def init_model_for_pruning(cls, model, mask_calculator=\"m4n2_1d\",\n             verbosity=3,\n             whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], \n             allowed_layer_names=None, disallowed_layer_names=[],\n             allow_recompute_mask=False, custom_layer_dict={}):\n        \"\"\"Call this method to modify your model to take advantage of sparse matrix multiplication.\n        Note that this call alone only augments the model with additional buffers needed for sparse MMA,\n        it does not enable use of sparse MMA. 
\n\n        If you are starting with a fresh model:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        if (training) ASP.init_optimizer_for_pruning(optimizer)\n        ASP.compute_sparse_masks() # sparsity is off by default; call this when you want to enable it.\n\n        If you are starting from a checkpoint:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        torch.load(...)\n        if (training) ASP.init_optimizer_for_pruning(optimizer)\n\n        Arguments:\n          model                    The model\n          mask_calculator          Either callable that computes mask given a tensor OR pattern string for sparse mask lib.\n          verbosity                Integer controlling verbosity level.\n                                   0 -> Only errors.\n                                   1 -> Errors and warnings.\n                                   2 -> Errors, warnings and info.\n                                   3 -> Errors, warnings, info and debug.\n          whitelist                Module types approved for sparsity.\n          allowed_layer_names      If not None, only layer names that appear in this list are considered for sparsity.\n          disallowed_layer_names   If not [], only layer names that do not appear in this list are considered for sparsity.\n          allow_recompute_mask     If True, stores pruned values so that dense weights can be restored.\n                                   Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.\n          custom_layer_dict        Dictionary of additional layer parameters to sparsify, e.g. {CustomLinear: ['weight']}\n          \n          [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM. 
\n        \"\"\"\n        assert (cls.__model is None), \"ASP has been initialized already.\"\n        cls.__model = model\n        cls.__verbosity = verbosity\n\n        if isinstance(mask_calculator, str):\n            def create_mask_from_pattern(param):\n                return create_mask(param, mask_calculator).bool()\n            cls.__calculate_mask = create_mask_from_pattern\n        else:\n            cls.__calculate_mask = mask_calculator #user defined function\n\n        # function to extract variables that will be sparsified. \n        # idea is that you will add one of these functions for each module type that can be sparsified.\n        if torchvision_imported:\n            print(\"[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.\")\n            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}\n        else:\n            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}\n        if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune\n            sparse_parameter_list.update(custom_layer_dict)\n            whitelist += list(custom_layer_dict.keys())\n\n        for module_type in whitelist:\n            assert (module_type in sparse_parameter_list), \"Module %s :: Don't know how to sparsify module.\" % module.dtype()\n\n        # find all sparse modules, extract sparse parameters and decorate\n        def add_sparse_attributes(module_name, module):\n            sparse_parameters = sparse_parameter_list[type(module)]\n            for p_name, p in module.named_parameters():\n                if p_name in sparse_parameters and p.requires_grad:\n                    # check for 
NVIDIA's TC compatibility: we check along the horizontal direction\n                    if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math\n                        print(\"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                        continue\n                    if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C\n                        print(\"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                        continue\n                    \n                    if cls.__verbosity >= 3:\n                        print(\"[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                    \n                    mask = torch.ones_like(p).bool()\n                    buffname = p_name.split(\".\")[-1] # buffer names cannot contain \".\"\n                    module.register_buffer('__%s_mma_mask' % buffname, mask)\n                    if allow_recompute_mask:\n                        pruned = torch.zeros_like(p).cpu()\n                        module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)\n                    else:\n                        pruned = None\n                    cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))\n                else:\n                    if cls.__verbosity >= 3:\n                        print(\"[ASP] Not sparsifying %s::%s of size=%s and type=%s\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n\n        for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):\n            add_sparse_attributes(name, sparse_module)\n\n    
@classmethod\n    def init_optimizer_for_pruning(cls, optimizer):\n        \"\"\"Call this method to monkey patch optimizer step function so that masks can be applied to\n        gradients and weights during training.\n        You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)\n        \"\"\"\n        assert (cls.__optimizer is None), \"ASP has initialized optimizer already.\"\n        assert (cls.__calculate_mask is not None), \"Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning.\"\n\n        # store pointer to original optimizer step method\n        cls.__optimizer = optimizer\n        cls.__optimizer.__step = optimizer.step\n\n        def __step(opt_self, *args, **kwargs):\n            # prune gradients before step method\n            with torch.no_grad():\n                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                    if p.grad is not None: #thx pjudd\n                        p.grad.mul_(mask)\n            # call original optimizer step method\n            rval = opt_self.__step(*args, **kwargs)\n            # prune parameters after step method\n            with torch.no_grad():\n                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                    p.mul_(mask)\n            return rval\n        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)\n\n    @classmethod\n    def compute_sparse_masks(cls):\n        \"\"\"Call this method to enable sparsity.\n        If init(...) 
was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.\n        \"\"\"\n        with torch.no_grad():\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel(): # when recalculating masks\n                    # restore dense parameter if allow_recompute_mask is enabled\n                    assert (pruned is not None), \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    p.add_(pruned.cuda())\n\n                mask.set_(cls.__calculate_mask(p))\n\n                if pruned is not None: # stow away pruned weights to cpu\n                    pruned.set_((p * (~mask)).cpu())\n\n                p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights\n                if cls.__verbosity >= 2:\n                    print(\"[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s\" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))\n\n    @classmethod\n    def restore_pruned_weights(cls):\n        \"\"\"Call this method to disable sparsity and restore all weights.\n        This will only work if init(...) 
was called with allow_recompute_mask=True.\n        \"\"\"\n        with torch.no_grad():\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel():\n                    assert (pruned is not None), \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    p.add_(pruned.cuda())\n                    mask.fill_(1)\n                    pruned.zero_()\n                    if cls.__verbosity >= 2:\n                        print(\"[ASP] Disabled sparsity for %s::%s (dense weights restored)\" % (module_name, p_name))\n\n    @classmethod\n    def is_sparsity_enabled(cls):\n        \"\"\"Call this method to determine if sparsity is enabled in the model.\n        The typical use case is right after checkpoint has been loaded.\n        \"\"\"\n        total,sp100,sp50 = 0,0,0\n        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n            total += 1\n            mask_sum = mask.sum()\n            mask_numel = mask.numel()\n            if mask_sum == mask_numel:\n                sp100 += 1\n            elif mask_sum*2 == mask_numel:\n                sp50 += 1\n\n        assert (total == sp100 or total == sp50), \"Inconsistent model sparsity\"\n        if total == sp100:\n            return False\n        elif total == sp50:\n            return True\n    \n    @classmethod\n    def prune_trained_model(cls, model, optimizer):\n        # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)\n        cls.init_model_for_pruning(model, mask_calculator=\"m4n2_1d\", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)\n        cls.init_optimizer_for_pruning(optimizer)\n        cls.compute_sparse_masks()\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/sparse_masklib.py",
    "content": "import sys\nimport torch\nimport numpy as np\nimport collections\nfrom itertools import permutations\n\n\n\"\"\" compute density (helper fn to compute % NNZs in a tensor) \"\"\"\ndef fill(x):\n    return float(x.nonzero().size(0))/torch.numel(x)\n\n\"\"\" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) \"\"\"\ndef reshape_1d(matrix, m):\n    # If not a nice multiple of m, fill with zeroes.\n    if matrix.shape[1] % m > 0:\n        mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0)\n        mat[:, :matrix.shape[1]] = matrix\n        shape = mat.shape\n        return mat.view(-1,m),shape\n    else:\n        return matrix.view(-1,m), matrix.shape\n\n\"\"\" return all possible m:n patterns in a 1d vector \"\"\"\nvalid_m4n2_1d_patterns = None\ndef compute_valid_1d_patterns(m,n):\n    # Early exit if patterns was already created.\n    global valid_m4n2_1d_patterns\n\n    if m==4  and n==2 and valid_m4n2_1d_patterns  is not None: return valid_m4n2_1d_patterns\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))\n    if m == 4  and n == 2: valid_m4n2_1d_patterns  = valid_patterns       \n    return valid_patterns\n\n\"\"\" m:n 1d structured best \"\"\"\ndef mn_1d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_1d_patterns(m,n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)\n    mat,shape = reshape_1d(matrix,m)\n    pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)\n    mask[:] = patterns[pmax[:]]\n    mask = mask.view(matrix.shape)\n    return mask\n\ndef m4n2_1d(mat, density):\n    return mn_1d_best(mat, 4, 2)\n\n\"\"\"\n  Below 2d-masking related code is targeted more for training (from scratch).\n  2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop\n  
phase of training algorithm. Acceleration comes from using SpMMA instructions in\n  Tensor Cores of NVIDIA Ampere GPU Architecture \n  (note: this code does not do the acceleration, GPU kernels are required for this).\n  1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern\n  along the horizontal (logical) direction.\n  During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask\n  weight tensor such that their transposed versions are also 2:4 sparse along the\n  horizontal (logical) direction. Thus, with 2d pruning, weight tensors are \n  2:4 sparse along row and column directions.\n \"\"\"\n\n\"\"\" m:n 2d structured pruning: greedy method to select mask \"\"\"\ndef mn_2d_greedy(matrix, m, n):\n    # Convert to numpy\n    mat = matrix.cpu().detach().numpy()\n    mask = np.ones(mat.shape, dtype=int)\n\n    rowCount = int(mat.shape[0]/m) * m\n    colCount = int(mat.shape[1]/m) * m\n    for rowStartIdx in range(0, rowCount, m):\n        rowEndIdx = rowStartIdx + m\n        for colStartIdx in range(0, colCount, m):\n            colEndIdx = colStartIdx + m\n            matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))\n            maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])\n            maskSub.fill(0.0)\n            matrixVecView = matrixSub.reshape(-1)\n            maskVecView   = maskSub.reshape(-1)\n            linearIdx = np.argsort(matrixVecView)\n            matrixIdx = [(int(x/m), x % m) for x in linearIdx]\n            rowCounter = collections.Counter()\n            colCounter = collections.Counter()\n            for currIdx in range(len(linearIdx) - 1, -1, -1):\n                currMatrixEntry = matrixIdx[currIdx]\n                if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):\n                    continue\n                #end if\n                maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0\n          
      rowCounter[currMatrixEntry[0]] += 1\n                colCounter[currMatrixEntry[1]] += 1\n\n    return torch.tensor(mask).cuda() # mask is a numpy array; build the tensor first, then move it to the GPU\n\ndef m4n2_2d_greedy(mat, density):\n    return mn_2d_greedy(mat, 4, 2)\n\n\"\"\" return all possible m:n patterns in a mxn block. \"\"\"\nvalid_m4n2_2d_patterns = None\ndef compute_valid_2d_patterns(m,n):\n    # Early exit if patterns was already created.\n    global valid_m4n2_2d_patterns\n    if valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns\n\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    patterns = list(set(permutations(patterns.tolist())))\n    patterns = patterns + patterns\n    patterns = torch.Tensor(list(set(permutations(patterns,m))))\n\n    valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)\n    valid_patterns = torch.Tensor(valid.shape[0],m,m)\n    valid_patterns[:] = patterns[valid[:]]\n\n    if m == 4  and n == 2: valid_m4n2_2d_patterns  = valid_patterns\n    return valid_patterns\n\n\"\"\" reshape matrix into mxn tiles: (h,w) -> (h/m, w/n, m*n); helper assumed by mn_2d_best, reconstructed here \"\"\"\ndef reshape_2d(matrix, m, n):\n    return matrix.view(matrix.shape[0]//m, m, matrix.shape[1]//n, n).permute(0,2,1,3).contiguous().view(matrix.shape[0]//m, matrix.shape[1]//n, m*n)\n\n\"\"\" inverse of reshape_2d: (h/m, w/n, m, n) tiles -> (h,w) matrix \"\"\"\ndef reshape_2d_inv(tiles):\n    return tiles.permute(0,2,1,3).contiguous().view(tiles.shape[0]*tiles.shape[2], tiles.shape[1]*tiles.shape[3])\n\n\"\"\" m:n 2d structured pruning: exhaustive method to select best mask \"\"\"\ndef mn_2d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_2d_patterns(m,n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1)\n    mat = reshape_2d(matrix,m,m).abs()\n    pmax = torch.argmax(torch.matmul(mat,patterns.view(patterns.shape[0],m*m).t()), dim=2)\n\n    # Copy best m:n patterns into mask.\n    mat = mat.view(mat.shape[0]*mat.shape[1],-1)\n    pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1,mat.shape[1])\n    patterns = patterns.view(patterns.shape[0],patterns.shape[1]*patterns.shape[2])\n    mat = torch.gather(patterns,0,pmax)\n    mat = reshape_2d_inv(mat.view(matrix.shape[0]//m,matrix.shape[1]//m,m,m))\n    mask.copy_(mat.type(mask.type()))\n    return mask\n\ndef m4n2_2d_best(mat, density):\n    return mn_2d_best(mat, 4, 2)\n\n\n\"\"\" returns a sparse mask 
\"\"\"\ndef create_mask(tensor, pattern=\"m4n2_1d\", density=0.5):\n    # Reshape tensor and mask.\n    shape = tensor.shape\n    ttype = tensor.type()\n    t = tensor.float().contiguous()\n\n    # 1d-tensor\n    if len(shape) == 1:\n        t = t.view(1, shape[0])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 2d-tensor (in, out)\n    elif len(shape) == 2:\n        t = t.view(shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 3d-tensor (batch, in, out)\n    elif len(shape) == 3:\n        t = t.view(shape[0]*shape[1], shape[2])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 4d-tensor (in, out, h, w)\n    elif len(shape) == 4:\n        \"\"\"\n        # transformers (bmm)\n        t = t.view(shape[0]*shape[1]*shape[2], shape[3])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n        \"\"\"\n        # convs\n        t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()      \n        return mask.view(shape).type(ttype)\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/test/checkpointing_test_part1.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    #\n    # PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, 
allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    torch.save({\n            'step': step,\n            'verbosity': args.verbosity,\n            'seed2': args.seed2,\n            'pattern': args.pattern,\n            'whitelist': args.whitelist,\n            'allow_recompute_mask': args.allow_recompute_mask,\n            'model_state_dict': model.state_dict(),\n            'optimizer_state_dict': optimizer.state_dict(),\n            }, args.checkpoint_path)\n\nif __name__ == '__main__':\n    class Args:\n        verbosity=3\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/test/checkpointing_test_part2.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(step, args, model_state_dict, optimizer_state_dict):\n    #\n    # PART2\n    #\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, 
allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    torch.manual_seed(args.seed2)\n    model.load_state_dict(model_state_dict)\n    optimizer.load_state_dict(optimizer_state_dict)\n\n    print(\"Model sparsity is %s\" % (\"enabled\" if ASP.is_sparsity_enabled() else \"disabled\"))\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\nif __name__ == '__main__':\n    checkpoint = torch.load(\"part1.chkp\")\n    class Args:\n        verbosity = checkpoint['verbosity']\n        seed = 4873\n        seed2 = checkpoint['seed2']\n        pattern = checkpoint['pattern']\n        whitelist = checkpoint['whitelist']\n        allow_recompute_mask = checkpoint['allow_recompute_mask']\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict'])\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/test/checkpointing_test_reference.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n#\n# Reference run for checkpointing test (part1 + part2)\n#\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    #\n    # PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, 
whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    #\n    # PART 2\n    #\n\n    torch.manual_seed(args.seed2)\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\nif __name__ == '__main__':\n    class Args:\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/sparsity/test/toy_problem.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    # only prune linear layers, even though we also support conv1d, conv2d and conv3d\n    ASP.init_model_for_pruning(model, \"m4n2_1d\", whitelist=[torch.nn.Linear], 
allow_recompute_mask=True)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    # recompute sparse masks\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\n    # turn off sparsity\n    print(\"SPARSE :: \",one_ll)\n    ASP.restore_pruned_weights()\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)\n\nif __name__ == '__main__':\n    class Args:\n        batch_size = 32\n        input_features = 16\n        output_features = 8\n        hidden_features = 40\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        num_dense_steps_2 = 1500\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/fmha/test_fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n# \n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n# \n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\n\nimport sys\nimport torch\nimport numpy as np\nimport unittest\nimport math\n\nimport fmhalib as mha\n\ndef py_mha(qkv, amask, b, s, h, d):\n    qkv = qkv.view(b, s, h, 3, d)\n    q = qkv[:, :, :, 0, :].permute(0,2,1,3)\n    k = qkv[:, :, :, 1, :].permute(0,2,1,3)\n    v = qkv[:, :, :, 2, :].permute(0,2,1,3)\n    p = torch.matmul(q.float(), k.permute(0,1,3,2).float())\n    p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0\n    s = torch.softmax(p_masked, -1).to(qkv.dtype)\n    ctx = torch.matmul(s, v)\n    ctx = ctx.permute(0,2,1,3).contiguous()\n\n    ctx.retain_grad()\n\n    return ctx\n\nclass TestFMHA(unittest.TestCase):\n\n    def run_test(self, s, b):\n        print(f'Test s={s} b={b}')\n\n        torch.manual_seed(1234)\n        torch.cuda.manual_seed(1234)\n        \n        dtype = torch.float16\n        device = torch.device('cuda')\n\n        h = 16 \n        d = 64\n    \n        slens = [s] * b \n        a = torch.tensor(np.array([0] + slens), dtype=torch.int32)\n        amask = torch.ones(b,h,s,s, dtype=dtype, device=device)\n        seqlens = torch.tensor(slens, dtype=torch.int32, device=device)\n        cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)\n        total = cu_seqlens[-1].item()\n    \n        qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)\n    \n        qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, 
h,d)\n    \n        qkv.requires_grad = True\n    \n        if b < 4:\n            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)\n        else:\n            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)\n        ctx = ctx.view(b,s,h,d)\n    \n        ctx_ref = py_mha(qkv, amask, b,s,h,d)\n        self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))\n    \n        labels = torch.randn_like(ctx_ref)\n        diff = ctx_ref - labels\n        l = (diff * diff).sum() / b\n        l.backward()\n    \n        dw = ctx_ref.grad.permute(0,2,1,3) \n    \n        dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()\n    \n        if b < 4:\n            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)\n        else:\n            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)\n        \n        dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)\n    \n        self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))\n\n    def test_128(self):\n        self.run_test(128, 32)\n\n    def test_256(self):\n        self.run_test(256, 32)\n\n    def test_384(self):\n        self.run_test(384, 32)\n\n    def test_512(self):\n        self.run_test(512, 32)\n        self.run_test(512, 2)\n        self.run_test(512, 3)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/layer_norm/test_fast_layer_norm.py",
    "content": "import torch\nimport unittest\nimport numpy as np\n\nimport torch.nn.functional as F\n\nfrom apex.contrib.layer_norm import FastLayerNorm\n\nimport fast_layer_norm as fln\n\n\nclass GPUTimer:\n    def __init__(self, stream):\n        self.start_ = torch.cuda.Event(enable_timing=True)\n        self.stop_ = torch.cuda.Event(enable_timing=True)\n        self.stream_ = stream\n    def start(self):\n        self.stream_.record_event(self.start_)\n    def stop(self):\n        self.stream_.record_event(self.stop_)\n    def sync(self):\n        self.stream_.synchronize()\n    def millis(self):\n        return self.start_.elapsed_time(self.stop_)\n\ndef size_in_bytes(t):\n    return torch.numel(t) * t.element_size()\ndef abs_err(x, y):\n    xf = x.float()\n    yf = y.float()\n    return ((xf-yf).abs().sum() / yf.abs().sum()).item()\n\n\n\nclass TestFastLayerNorm(unittest.TestCase):\n    \n    def setUp(self, seed=1234):\n        seed = 1234\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def test_ln_fp32(self):\n        self.run_test_layer_norm(torch.float32, atol=1e-5)\n    def test_ln_fp16(self):\n        self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3)\n\n    def run_test_layer_norm(self, dtype, atol, rtol=1e-5):\n        device = torch.device('cuda')\n        s = 512\n        b = 32\n        hidden_size = 1024\n        epsilon = 1e-5\n\n        x = torch.randn((s,b,hidden_size), dtype=dtype, device=device)  \n        beta = torch.randn(hidden_size, dtype=dtype, device=device)  \n        gamma = torch.randn(hidden_size, dtype=dtype, device=device)\n        x.requires_grad = True\n        beta.requires_grad = True\n        gamma.requires_grad = True\n\n        x2 = x.clone().detach()\n        beta2 = beta.clone().detach()\n        gamma2 = gamma.clone().detach()\n        x2.requires_grad = True\n        beta2.requires_grad = True\n        gamma2.requires_grad = True\n               \n        dummy_label = 
torch.randn_like(x)\n\n        y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon)\n\n        diff = y-dummy_label\n        l = (diff * diff).sum() / b\n        l.backward()\n\n        fln = FastLayerNorm(hidden_size).cuda()\n        fln.load_state_dict({'bias': beta2, 'weight':gamma2})\n        if dtype == torch.float16:\n            fln = fln.half()\n\n        y2 = fln(x2)\n        diff2 = (y2 - dummy_label)\n        l2 = (diff2 * diff2).sum() / b\n\n        l2.backward()\n\n        self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol))\n        self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol))\n        self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol))\n        self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol))\n    \n\n\n    def test_performance(self):\n        print()\n        runs = 1000\n        device = torch.device('cuda')\n        dtype =torch.float16\n        s = 512\n        b = 32\n        hidden_size = 1024\n        epsilon = 1e-5\n\n        x = torch.randn((s*b,hidden_size), dtype=dtype, device=device)  \n        beta = torch.randn(hidden_size, dtype=dtype, device=device)  \n        gamma = torch.randn(hidden_size, dtype=dtype, device=device)\n        dy = torch.randn_like(x)\n \n\n        stream = torch.cuda.Stream()\n        with torch.cuda.stream(stream):\n\n            timer = GPUTimer(stream)\n\n            #warmup\n            for r in range(runs):\n                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)\n           \n           \n            timer.start()\n            for r in range(runs):\n                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)\n            timer.stop()\n            timer.sync()\n\n            total_bytes_fwd = (size_in_bytes(x) \n                             + size_in_bytes(y) \n                             + size_in_bytes(gamma) \n                             + size_in_bytes(beta) \n                             + 
size_in_bytes(mu) \n                             + size_in_bytes(rsigma)\n                             )\n\n            ms_fwd = timer.millis() / runs\n            print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))\n         \n\n            timer.start()\n            for r in range(runs):\n                dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)\n            timer.stop()\n            timer.sync()\n\n            total_bytes_bwd = (size_in_bytes(x) \n                             + size_in_bytes(dx)\n                             + size_in_bytes(dy) \n                             + size_in_bytes(gamma) \n                             + size_in_bytes(dgamma)  \n                             + size_in_bytes(dbeta)  \n                             + size_in_bytes(mu) \n                             + size_in_bytes(rsigma)\n                             )\n\n\n            ms_bwd = timer.millis() / runs\n            print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nclass EncdecMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=False, \n                                             impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=False, \n                                             impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs_q = torch.randn(self.seq_length, 
self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_encdec_multihead_attn(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate from the layer outputs so input gradients are actually\n        # computed; calling backward on the leaf inputs would make the gradient\n        # comparison below vacuous.\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_encdec_multihead_attn_time_mask(self) :\n        grads         
 = torch.randn_like(self.tst_inputs_q)\n        time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        time_mask_bool = time_mask_byte.to(torch.bool)\n        \n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_bool,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_byte,\n                                               is_training=True)\n        \n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_encdec_multihead_attn_pad_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n        pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device(\"cuda\"), 
dtype=torch.uint8), 1)\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n        \n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=pad_mask_bool, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=pad_mask_byte, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nclass EncdecMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=True, \n                                             impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=True, \n                                             impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs_q = torch.randn(self.seq_length, 
self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_encdec_multihead_attn_norm_add(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate from the outputs so the gradient comparison is meaningful.\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=True, \n                                           include_norm_add=False, \n                                           separate_qkv_params=True, \n                                           mask_additive=True, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=True, \n                                           include_norm_add=False, \n                                           separate_qkv_params=True, \n                                           mask_additive=True, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = 
torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    \n    def test_self_multihead_attn_additive_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=mask, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=mask, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        \n        # Backpropagate from the outputs so the gradient comparison is meaningful.\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py",
    "content": "import torch\nimport unittest\nimport torch.nn.functional as F\nfrom apex.contrib.multihead_attn import fast_mask_softmax_dropout_func\n\nclass FusedSoftmaxTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda()\n        self.mask = self.mask.half()*-10000\n        self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        \n        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)\n\n    def test_fused_softmax(self) :\n        grads = torch.randn_like(self.tst_inputs)\n        y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)\n        y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)\n        y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length) \n        y_ref = F.softmax(y_ref, dim=-1)\n        y_ref = torch._fused_dropout(y_ref, 1.0)    \n   \n        y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)        \n        y_ref[0].backward(grads)\n        y_tst.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_self_multihead_attn.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=False, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=False, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_self_multihead_attn(self) :\n        grads         = 
torch.randn_like(self.tst_inputs)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate from the layer outputs so input gradients are actually\n        # computed; calling backward on the leaf inputs would make the gradient\n        # comparison below vacuous.\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\n    def test_self_multihead_attn_time_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        time_mask_bool= time_mask_byte.to(torch.bool)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                              
 need_weights=False, \n                                               attn_mask=time_mask_bool,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_byte,\n                                               is_training=True)\n\n        \n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_self_multihead_attn_pad_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=pad_mask_bool, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                
               self.tst_inputs,\n                                               key_padding_mask=pad_mask_byte, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        \n        self.ref_inputs.backward(grads)\n        self.tst_inputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=True, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=True, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_self_multihead_attn_norm_add(self) :\n        grads         = 
torch.randn_like(self.tst_inputs)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        self.ref_inputs.backward(grads)\n        self.tst_inputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/test_label_smoothing.py",
    "content": "import torch\nfrom apex.contrib import xentropy as label_smoothing\nimport unittest\n\nimport warnings\nimport random\nimport numpy as np\nimport time\n\ndef label_smoothing_raw(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    non_pad_mask = (target != padding_idx)\n    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))\n    nll_loss = nll_loss.squeeze(1)[non_pad_mask]\n    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]\n    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss\n    return loss\n\ndef label_smoothing_opt_1(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    pad_mask = (target == padding_idx)\n    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)\n    smooth_loss = logprobs.mean(dim=-1)\n    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss\n    loss.masked_fill_(pad_mask, 0)\n    return loss\n\nclass LabelSmoothingTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        # Set pytorch print precision\n        torch.set_printoptions(precision=10)\n\n    def gen_test_inputs(self, N, T, H, smoothing, padding_idx):\n        logits = torch.randn((N*T, H), dtype=torch.half, device='cuda',\n            requires_grad=True)\n        labels = torch.randint(0, H, [N*T], device='cuda')\n        for i in random.sample(range(N*T), N*T//6):\n            labels[i] = padding_idx\n        half_to_float = (logits.dtype == torch.half)\n\n        return logits, labels, half_to_float\n\n    def print_max_diff_elem(self, ref, tst):\n        ref, tst = ref.flatten(), tst.flatten()\n        diff = (ref - tst).abs().max()\n        idx = (ref - tst).abs().argmax()\n        print(\"Max atol idx: {}, diff: {:.6f}, ref: 
{:.6f}, tst: {:.6f}\".format(\n            idx, diff, ref[idx], tst[idx]))\n\n    def test_label_smoothing_function(self):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 10\n        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply\n\n        for i in range(iters):\n            logits, labels, half_to_float = self.gen_test_inputs(\n                N, T, H, smoothing, padding_idx)\n    \n            # Run original softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum()\n            loss.backward()\n            \n            ref_loss = loss.clone().detach()\n            ref_grad = logits.grad.clone().detach()\n\n            # Run optimized softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum()\n            loss.backward()\n\n            val_loss = loss.clone().detach()\n            val_grad = logits.grad.clone().detach()\n\n            # Validate\n            self.print_max_diff_elem(ref_grad, val_grad)\n            self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))\n            self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))\n\n    def test_label_smoothing_perf(self):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 1000\n        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply\n        print()\n\n        logits, labels, half_to_float = self.gen_test_inputs(\n            N, T, H, smoothing, padding_idx)\n    \n        # Run original softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in 
range(iters):\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\"Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n            time.time() - ts, iters, logits.grad.norm()))\n            \n        # Run optimized softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in range(iters):\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\"Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n            time.time() - ts, iters, logits.grad.norm()))\n\nif __name__ == '__main__':\n    unittest.main()\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/transducer/test_transducer_joint.py",
    "content": "import torch\nimport unittest\nfrom apex.contrib.transducer import TransducerJoint\nimport transducer_ref\n\nclass TransducerJointTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def gen_input(self, for_vector_kernel):\n        self.B = 4\n        T_min = 51\n        T_max = 101\n        U_min = 12\n        U_max = 25\n        if for_vector_kernel:\n            H = 512\n        else:\n            H = 509\n        dtype = torch.float16\n        device = \"cuda\"\n\n        self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)\n        self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) \n        self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max\n        self.dropout_prob = 0.5\n\n        # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by \n        # the loss function\n        for b in range(self.B):\n            self.h_grad[b, self.f_len[b]:, :, :] = 0\n            self.h_grad[b, :, self.g_len[b]:, :] = 0\n        self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)\n        \n\n    def _pack(self, x, f_len, g_len):\n        B = x.size(0)\n        list_x = []\n        for b in range(B):\n            list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        return x_packed\n\n    def _unpack(self, x, f_len, g_len):\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)\n        B = self.h_grad.size(0)\n        H = self.h_grad.size(-1)\n        for b in range(B):\n            my_batch_offset = 0 if b == 0 else batch_offset[b-1]\n            my_f_len = f_len[b]\n            my_g_len = g_len[b]\n            for t in range(my_f_len):\n                x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : \n                                                my_batch_offset + t*my_g_len + my_g_len]\n        return x_unpacked\n        \n    def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):\n        self.gen_input(for_vector_kernel=for_vector_kernel)\n        # Generate reference\n        f_ref = self.f_tst.data.clone()\n        g_ref = self.g_tst.data.clone()\n        f_ref.requires_grad = True\n        g_ref.requires_grad = True\n        \n        my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, \n                                    dropout_prob=self.dropout_prob, probe_mask=True)\n        if not pack_output:\n            h_tst = my_joint(   f=self.f_tst, \n                                
g=self.g_tst, \n                                f_len=self.f_len, \n                                g_len=self.g_len)\n            h_tst.backward(self.h_grad)\n            if dropout:\n                mask = my_joint.mask_probe[0]\n        else:\n            batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)\n            h_tst = my_joint(   f=self.f_tst, \n                                g=self.g_tst, \n                                f_len=self.f_len, \n                                g_len=self.g_len, \n                                batch_offset=batch_offset, \n                                packed_batch=batch_offset[-1])\n            h_tst.backward(self.h_grad_packed)\n            if dropout:\n                mask_packed = my_joint.mask_probe[0]\n                mask = self._unpack(mask_packed, self.f_len, self.g_len)\n\n        # reference\n        h_ref, f_grad_ref, g_grad_ref \\\n            = transducer_ref.transducer_joint_reference(f=f_ref, \n                                                        g=g_ref, \n                                                        h_grad=self.h_grad, \n                                                        f_len=self.f_len, \n                                                        g_len=self.g_len, \n                                                        pack_output=pack_output,\n                                                        relu=relu,\n                                                        dropout=dropout,\n                                                        dropout_prob=self.dropout_prob,\n                                                        mask=mask if dropout else None)\n        \n        f_grad_tst = self.f_tst.grad\n        g_grad_tst = self.g_tst.grad\n        \n        self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, 
atol=1e-4, rtol=1e-4))\n\n    def test_transducer_joint(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=False, dropout=False)\n\n    def test_transducer_joint_vec(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)\n\n    def test_transducer_joint_pack(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)\n\n    def test_transducer_joint_vec_pack(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)\n\n    def test_transducer_joint_relu(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=True, dropout=False)\n\n    def test_transducer_joint_vec_relu(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)\n\n    def test_transducer_joint_pack_relu(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)\n\n    def test_transducer_joint_vec_pack_relu(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)\n\n    def test_transducer_joint_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=True, dropout=True)\n\n    def test_transducer_joint_vec_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)\n\n    def test_transducer_joint_pack_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)\n\n    def test_transducer_joint_vec_pack_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/transducer/test_transducer_loss.py",
    "content": "import torch\nimport unittest\nfrom apex.contrib.transducer import TransducerLoss\nimport transducer_ref\n\nclass TransducerLossTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def gen_input(self, scalar_t, for_vector_kernel):\n        self.B = 5\n        T_min = 23\n        T_max = 51\n        U_min = 12\n        U_max = 25\n        V = 16 if for_vector_kernel else 14\n        self.blank_idx = V - 1\n        device = \"cuda\"\n\n        self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, \n                                    device=device)\n        self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)\n        self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) \n        self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1\n        self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)\n        # Generate reference\n        x_ref = self.x_tst.data.clone()\n        x_ref.requires_grad = True\n        loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)\n        _, _, self.grad_ref, self.loss_ref \\\n            = transducer_ref.transducer_loss_reference( x=x_ref, \n                                                        label=self.y, \n                                                        f_len=self.f_len, \n                                                        y_len=self.y_len, \n                                                        blank_idx=self.blank_idx, \n                                                        loss_grad=loss_grad)\n\n    def _pack(self, x):\n        list_x = []\n        for b in range(self.B):\n            
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)\n        return x_packed, batch_offset\n\n    def _unpack(self, x):\n        x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), \n                                    dtype=x.dtype, device=x.device)\n        for b in range(self.B):\n            my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]\n            my_f_len = self.f_len[b]\n            my_g_len = self.y_len[b] + 1\n            for t in range(my_f_len):\n                for u in range(my_g_len):\n                    x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]\n        return x_unpacked\n\n    def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):\n        self.gen_input(scalar_t, for_vector_kernel)\n        my_loss = TransducerLoss(  fuse_softmax_backward=fuse_softmax_backward, \n                                    packed_input=packed_input) \n        if not packed_input:\n            loss_tst = my_loss( x=self.x_tst,\n                                label=self.y, \n                                f_len=self.f_len, \n                                y_len=self.y_len, \n                                blank_idx=self.blank_idx)\n            loss_tst.mean().backward() \n            grad_tst = self.x_tst.grad\n        else:\n            loss_tst = my_loss( x=self.x_tst_packed,\n                                label=self.y, \n                                f_len=self.f_len, \n                                y_len=self.y_len, \n                                blank_idx=self.blank_idx,\n                                batch_offset=self.batch_offset, \n                                max_f_len=max(self.f_len))\n      
      loss_tst.mean().backward()\n            grad_tst_packed = self.x_tst_packed.grad\n            grad_tst = self._unpack(grad_tst_packed)\n        \n        return loss_tst, grad_tst\n\n    def test_transducer_loss_fp32(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float32,\n                                                        fuse_softmax_backward=False,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))\n\n    def test_transducer_loss_fp16(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=False,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=True,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion_packed(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n            
                                            fuse_softmax_backward=True,\n                                                        packed_input=True,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion_packed_vec(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=True,\n                                                        packed_input=True,\n                                                        for_vector_kernel=True)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "KoSentenceT5/apex/contrib/test/transducer/transducer_ref.py",
    "content": "import torch\nimport numpy as np\nimport pdb\n\ndef transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):\n    def log_sum_exp(a, b):\n        if (a >= b):\n            return a + torch.log(1 + torch.exp(b-a))\n        else:\n            return b + torch.log(1 + torch.exp(a-b))\n\n    def forward_alpha(x, label, f_len, y_len, blank_idx):\n        B, T, U, V = x.size()\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\n        alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\n        for b in range(B):\n            alpha[b, 0, 0] = 0\n            for t in range(1, f_len[b]):\n                alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]\n            for u in range(1, y_len[b]+1):\n                alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]\n            for t in range(1, f_len[b]):\n                for u in range(1, y_len[b]+1):\n                    curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]\n                    next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]\n                    alpha[b, t, u] = log_sum_exp(curr_, next_) \n        return alpha\n\n    def forward_beta(x, label, f_len, y_len, blank_idx):\n        B, T, U, V = x.shape\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\n        beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\n        for b in range(B):\n            beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]\n            for t in range(f_len[b]-2, -1, -1):\n                beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx] \n            for u in range(y_len[b]-1, -1, -1):\n                beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]\n            for t in range(f_len[b]-2, -1, -1):\n                for u in range(y_len[b]-1, -1, -1):\n                    curr_ = beta[b, t+1, u] + 
x[b, t, u, blank_idx] \n                    next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]\n                    beta[b, t, u] = log_sum_exp(curr_, next_) \n        return beta\n\n    def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):\n        grad = torch.zeros_like(x)\n        B, T, U, V = x.size()\n        for b in range(B):\n            common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]\n            # next\n            for u in range(y_len[b]):\n                grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u] \n                                                        + beta[b, :f_len[b], u+1] \n                                                        + x[b, :f_len[b], u, label[b, u]])\n\n            # current\n            grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \\\n                = -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1] \n                    + beta[b, 1:f_len[b], :y_len[b]+1] \n                    + x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])\n\n            grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]\n                                                         + x[b, f_len[b]-1, y_len[b], blank_idx])\n     \n        return grad\n\n    x_log = torch.nn.functional.log_softmax(x, dim=-1)\n    alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)\n    beta = forward_beta(x_log, label, f_len, y_len, blank_idx)\n    grad = backward(x_log, label, f_len, y_len, alpha, beta, \n                        loss_grad, blank_idx)\n    x_log.backward(grad)\n    loss = -beta[:, 0, 0]\n    loss = loss.to(x.dtype)\n    return alpha, beta, x.grad, loss\n\n\ndef transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, \n                                dropout_prob=0, mask=None):\n    if dropout and mask is None:\n        raise NotImplementedError(\"mask needs to be supplied to test dropout.\")\n    B, T, H = f.size()\n    U = 
g.size(1)\n    f_expand = f.unsqueeze(dim=2)\n    g_expand = g.unsqueeze(dim=1)\n    h = f_expand + g_expand\n    if relu:\n        h = torch.nn.functional.relu(h)\n    if dropout:\n        h *= mask\n        scale = 1/(1-dropout_prob)\n        h *= scale\n    h.backward(h_grad)\n\n    if pack_output == False:\n        # intentionally set don't-care region to -1 to test if transducer joint\n        # write these regions to avoid NaN and inf\n        for b in range(B):\n            h[b, f_len[b]:] = -1\n            h[b, :, g_len[b]:] = -1\n\n        return h, f.grad, g.grad \n\n    # packing\n    list_to_pack = []\n    for b in range(B):\n        list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))\n    h_packed = torch.cat(list_to_pack)\n    return h_packed, f.grad, g.grad\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/transducer/__init__.py",
    "content": "from .transducer import TransducerJoint\nfrom .transducer import TransducerLoss"
  },
  {
    "path": "KoSentenceT5/apex/contrib/transducer/transducer.py",
    "content": "import torch\nimport transducer_loss_cuda\nimport transducer_joint_cuda\n\nclass TransducerJoint(torch.nn.Module):\n    \"\"\"Transducer joint\n    Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural \n    Networks\n\n    Arguments:\n        pack_output (bool, optional): whether to pack the output in a compact form with don't-care \n        data being removed. (default: False)\n        relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1  \n        (default: False)\n        dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1  \n        (default: False)\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. \n            (default: 1)\n        fwd_tile_size (int, optional): tile size used in forward operation. This argument will be \n        ignored if opt != 1. (default: 4) \n        dropout_prob (float, optional): dropout probability. (default: 0.0)\n        probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout\n        operation. When this argument is set to True, the mask can be accessed through \n        self.mask_probe. 
(default: False)\n    \"\"\"\n\n    def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, \n                    dropout_prob=0, probe_mask=False):\n        super(TransducerJoint, self).__init__() \n        self.pack_output = pack_output\n        self.relu = relu\n        self.dropout = dropout\n        self.dropout_prob = dropout_prob\n        self.opt = opt\n        self.fwd_tile_size = fwd_tile_size\n        self.dummy_batch_offset = torch.empty(0)\n        masked = self.relu or self.dropout\n        self.mask_probe = [] if masked and probe_mask else None\n        if masked and opt != 1:\n            raise NotImplementedError(\"ReLU and dropout fusion is only supported with opt=1\")\n\n\n    def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):\n        \"\"\"Forward operation of transducer joint\n\n        Arguments:\n            f (tensor): transcription vector from the encode block of shape (B, T, H).\n            g (tensor): prediction vector from the predict block of shape (B, U, H).\n            f_len (tensor): length of transcription vector for each batch.\n            g_len (tensor): length of prediction vector minus 1 for each batch.\n            batch_offset (tensor, optional): tensor containing the offset of each batch\n                in the results. For example, batch offset can be obtained from: \n                batch_offset = torch.cumsum(f_len*g_len, dim=0)\n                This argument is required if pack_output == True, and is ignored if \n                pack_output == False. (default: None)\n            packed_batch (int, optional): the batch size after packing. This argument is \n                ignored if pack_output == False. 
(default: 0)\n        \"\"\"\n        my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset\n        if self.pack_output and (batch_offset is None or packed_batch == 0):\n            raise Exception(\"Please specify batch_offset and packed_batch when packing is enabled\")\n        dropout = self.dropout and self.training    # only apply dropout in training\n        return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, \n                                            my_batch_offset, packed_batch, self.opt, \n                                            self.fwd_tile_size, self.dropout_prob, self.mask_probe)\n\n\nclass TransducerLoss(torch.nn.Module):\n    \"\"\"Transducer loss\n    Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural \n    Networks\n\n    Arguments:\n        fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with\n            softmax. (default: True)\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized \n            algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)\n        packed_input (bool, optional): whether the input is packed in a compact form with don't-care \n        data being removed. 
(default: False)\n    \"\"\"\n    def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):\n        super(TransducerLoss, self).__init__() \n        self.fuse_softmax_backward = fuse_softmax_backward\n        self.opt = opt\n        self.packed_input = packed_input\n        self.dummy_batch_offset = torch.empty(0)\n\n\n    def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, \n                debug_list=None):\n        \"\"\"Forward operation of transducer loss\n\n        Arguments:\n            x (tensor): input tensor to the loss function with a shape of (B, T, U, H).\n            label (tensor): labels for the input data.\n            f_len (tensor): lengths of the inputs in the time dimension for each batch.\n            y_len (tensor): lengths of the labels for each batch.\n            blank_idx (int): index for the null symbol.\n            batch_offset (tensor, optional): tensor containing the offset of each batch\n                in the input. For example, batch offset can be obtained from: \n                batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)\n                This argument is required if packed_input == True, and is ignored if \n                packed_input == False. (default: None)\n            max_f_len (int, optional): maximum length of the input in the time dimension.\n                For example, it can be obtained as \n                max_f_len = max(f_len)\n                This argument is required if packed_input == True, and is ignored if \n                packed_input == False. (default: None)\n            debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated \n                in the forward operation will be attached to this list for debugging purposes. 
\n                (default: None)\n        \"\"\"\n        if self.packed_input:\n            if batch_offset is None or max_f_len is None:\n                raise Exception(\"Please specify batch_offset and max_f_len when packing is enabled\")\n            my_batch_offset = batch_offset\n            my_max_f_len = max_f_len\n        else:\n            my_batch_offset = self.dummy_batch_offset\n            my_max_f_len = x.size(1)\n        return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, \n                                            blank_idx, self.fuse_softmax_backward, debug_list, \n                                            self.opt, self.packed_input)\n\nclass TransducerLossFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, \n                fuse_softmax_backward, debug_list, opt, packed_input):\n        # log_softmax is applied in both cases; when the softmax backward is not fused\n        # into the loss backward, its gradient is handled in backward() instead.\n        if not fuse_softmax_backward:\n            with torch.enable_grad():\n                x = torch.nn.functional.log_softmax(x, dim=-1)\n        else:\n            x = torch.nn.functional.log_softmax(x, dim=-1)\n        alpha, beta, loss = transducer_loss_cuda.forward(   x, label, f_len, y_len, batch_offset, \n                                                            max_f_len, blank_idx, opt, packed_input)\n        if debug_list == []:\n            debug_list += [alpha, beta]\n        ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)\n        ctx.blank_idx = blank_idx\n        ctx.fuse_softmax_backward = fuse_softmax_backward\n        ctx.opt = opt\n        ctx.packed_input = packed_input\n        ctx.max_f_len = max_f_len\n        return loss\n\n    @staticmethod\n    def backward(ctx, loss_grad):\n        x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors\n        x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, \n    
                                            batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, \n                                                ctx.fuse_softmax_backward, ctx.packed_input)\n        if not ctx.fuse_softmax_backward:\n            # Propagate the gradient through the log_softmax computed in forward:\n            # dL/dz = g - softmax(z) * sum(g), where x holds the saved log-probabilities\n            # (so x.exp() recovers softmax(z)).\n            x_grad = x_grad - x.exp() * x_grad.sum(dim=-1, keepdim=True)\n        return x_grad, None, None, None, None, None, None, None, None, None, None\n\nclass TransducerJointFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, \n                opt, fwd_tile_size, dropout_prob, mask_probe):\n        h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, \n                                            pack_output, relu, dropout, dropout_prob, fwd_tile_size)\n        masked = relu or dropout\n        if masked:\n            ctx.save_for_backward(h[1], f_len, g_len, batch_offset)\n            if mask_probe is not None:\n                mask_probe.append(h[1])\n        else:\n            ctx.save_for_backward(f_len, g_len, batch_offset)\n\n        ctx.pack_output = pack_output\n        ctx.masked = masked\n        ctx.max_f_len = f.size(1)\n        ctx.max_g_len = g.size(1)\n        ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1\n        return h[0]\n\n    @staticmethod\n    def backward(ctx, loss_grad):\n        if ctx.masked:\n            mask, f_len, g_len, batch_offset = ctx.saved_tensors\n            inp = [loss_grad, mask]\n        else:\n            f_len, g_len, batch_offset = ctx.saved_tensors\n            inp = [loss_grad]\n\n        f_grad, g_grad = transducer_joint_cuda.backward(    inp, f_len, g_len, batch_offset, \n                                                            ctx.max_f_len, ctx.max_g_len, \n                                                            ctx.pack_output, ctx.scale)\n\n        # One gradient (or None) is returned per input of forward(): f, g, f_len, g_len,\n        # pack_output, relu, dropout, batch_offset, packed_batch, opt, fwd_tile_size,\n        # dropout_prob, mask_probe.\n        return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, \\\n                None\n\n\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/xentropy/__init__.py",
    "content": "try:\n    import torch\n    import xentropy_cuda\n    from .softmax_xentropy import SoftmaxCrossEntropyLoss\n    del torch\n    del xentropy_cuda\n    del softmax_xentropy\nexcept ImportError as err:\n    print(\"apex was installed without --xentropy flag, contrib.xentropy is not available\")\n"
  },
  {
    "path": "KoSentenceT5/apex/contrib/xentropy/softmax_xentropy.py",
    "content": "import torch\nimport xentropy_cuda\n\nclass SoftmaxCrossEntropyLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):\n        losses, max_log_sum_exp = xentropy_cuda.forward(\n            logits, labels, smoothing, half_to_float)\n        losses.masked_fill_(labels==padding_idx, 0)\n\n        ctx.save_for_backward(logits, max_log_sum_exp, labels,\n            torch.FloatTensor([smoothing]),\n            torch.LongTensor([padding_idx]))\n\n        return losses\n\n    @staticmethod\n    def backward(ctx, grad_loss):\n        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors\n\n        if not grad_loss.is_contiguous():\n            grad_loss = grad_loss.contiguous()\n        grad_loss.masked_fill_(labels==padding_idx.item(), 0)\n        grad_logits = xentropy_cuda.backward(\n            grad_loss.contiguous(), logits, max_log_sum_exp,\n            labels, smoothing.item())\n\n        return grad_logits, None, None, None, None\n"
  },
  {
    "path": "KoSentenceT5/apex/fp16_utils/README.md",
    "content": "fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing Pytorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user.  To use `FP16_Optimizer`, only two lines of one's Python model need to change.\n\n#### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)\n\n#### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple)\n\n#### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\n#### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model)\n\n\nfp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses.  \n\n#### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management)\n\nThe [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling.  These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically.\n"
  },
  {
    "path": "KoSentenceT5/apex/fp16_utils/__init__.py",
    "content": "from .fp16util import (\n    BN_convert_float,\n    network_to_half,\n    prep_param_lists,\n    model_grads_to_master_grads,\n    master_params_to_model_params,\n    tofp16,\n    to_python_float,\n    clip_grad_norm,\n    convert_module,\n    convert_network,\n    FP16Model,\n)\n\nfrom .fp16_optimizer import FP16_Optimizer\nfrom .loss_scaler import LossScaler, DynamicLossScaler\n"
  },
  {
    "path": "KoSentenceT5/apex/fp16_utils/fp16_optimizer.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\nfrom ..amp._amp_state import _amp_state, maybe_print\nfrom ..amp.scaler import LossScaler\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm\n\n# TODO:  Update overflow check + downscale to use Carl's fused kernel.\nclass FP16_Optimizer(object):\n    def __init__(self, \n                 init_optimizer, \n                 static_loss_scale=1.0, \n                 dynamic_loss_scale=False,\n                 dynamic_loss_args=None,\n                 verbose=True):\n        print(\"Warning:  FP16_Optimizer is deprecated and dangerous, and will be deleted soon.  \"\n              \"If it still works, you're probably getting lucky.  \"\n              \"For mixed precision, use the documented API https://nvidia.github.io/apex/amp.html, with opt_level=O1.\")\n\n        if not torch.cuda.is_available:\n            raise SystemError(\"Cannot use fp16 without CUDA.\")\n\n        self.verbose = verbose\n\n        self.optimizer = init_optimizer\n        # init_state_dict sets up an alternative way to cast per-param state tensors.\n        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.\n        # init_state_dict = init_optimizer.state_dict()\n\n        self.fp16_groups = []\n        self.fp32_from_fp16_groups = []\n        self.fp32_from_fp32_groups = []\n        for i, param_group in enumerate(self.optimizer.param_groups):\n            self.maybe_print(\"FP16_Optimizer processing param group {}:\".format(i))\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(param_group['params']):\n                if 
param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        self.maybe_print(\"FP16_Optimizer received torch.cuda.HalfTensor with {}\"\n                                         .format(param.size()))\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        param_group['params'][i] = master_param\n                        fp32_from_fp16_params_this_group.append(master_param)\n                        # Reset existing state dict key to the new master param.\n                        # We still need to recast per-param state tensors, if any, to FP32.\n                        if param in self.optimizer.state:\n                           self.optimizer.state[master_param] = self.optimizer.state.pop(param) \n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        self.maybe_print(\"FP16_Optimizer received torch.cuda.FloatTensor with {}\"\n                                         .format(param.size()))\n                        fp32_params_this_group.append(param)\n                        param_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Wrapped parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"  \n                                        \"Received {}\".format(param.type()))\n            \n            self.fp16_groups.append(fp16_params_this_group)\n            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            self.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n        self.all_fp16_params = []\n        for group in self.fp16_groups:\n            self.all_fp16_params += group\n\n        self.all_fp32_from_fp16_params = []\n        for group in self.fp32_from_fp16_groups:\n            self.all_fp32_from_fp16_params += group\n\n        self.all_fp32_from_fp32_params = []\n        for group in self.fp32_from_fp32_groups:\n            self.all_fp32_from_fp32_params += group\n\n        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors\n        self.optimizer.load_state_dict(self.optimizer.state_dict())\n        # alternative way to cast per-param state tensors:\n        # self.optimizer.load_state_dict(init_state_dict)\n\n        if dynamic_loss_scale:\n            self.dynamic_loss_scale = True\n            if dynamic_loss_args is not None:\n                self.loss_scaler = LossScaler(\"dynamic\", **dynamic_loss_args)\n            else:\n                self.loss_scaler = LossScaler(\"dynamic\")\n        else:\n            self.dynamic_loss_scale = False\n            self.loss_scaler = LossScaler(static_loss_scale)\n\n        self.overflow = False\n        self.first_closure_call_this_step = True\n\n        self.clip_grad_norm = clip_grad_norm\n\n        # TODO:  Centralize exposure and import error checking for the C backend.\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_scale = amp_C.multi_tensor_scale\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0]);\n\n    # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact\n    # of having to support FP16_Optimizer separately, for the 
time being.\n    def maybe_print(self, msg):\n        if self.verbose:\n            print(msg)\n            \n    def __getstate__(self):\n        raise RuntimeError(\"FP16_Optimizer should be serialized using state_dict().\")\n\n    def __setstate__(self, state):\n        raise RuntimeError(\"FP16_Optimizer should be deserialized using load_state_dict().\")\n\n    def zero_grad(self, set_grads_to_None=False):\n        \"\"\"\n        Zero fp32 and fp16 parameter grads.\n        \"\"\"\n        # In principle, only the .grad attributes of the model params need to be zeroed,\n        # because gradients are copied into the FP32 master params.  However, we zero\n        # all gradients owned by the optimizer, just to be safe:\n        for group in self.optimizer.param_groups:\n             for p in group['params']:\n                 if set_grads_to_None:\n                     p.grad = None\n                 else:\n                     if p.grad is not None:\n                         p.grad.detach_()\n                         p.grad.zero_()\n\n        # Zero fp16 gradients owned by the model:\n        for fp16_group in self.fp16_groups:\n            for param in fp16_group:\n                if set_grads_to_None:\n                    param.grad = None\n                else:\n                    if param.grad is not None:\n                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()\n                        param.grad.zero_()\n\n    # Should not be used anymore.\n    # def _check_overflow(self):\n    #     params = []\n    #     for group in self.fp16_groups:\n    #         for param in group:\n    #             params.append(param)\n    #     for group in self.fp32_from_fp32_groups:\n    #         for param in group:\n    #             params.append(param)\n    #     self.overflow = self.loss_scaler.has_overflow(params)\n\n    # def _update_scale(self, has_overflow=False):\n    #     self.loss_scaler.update_scale(has_overflow)\n\n    def 
_master_params_to_model_params(self):\n        if multi_tensor_applier.available:\n            if len(self.all_fp16_params) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_scale,\n                    self._dummy_overflow_buf,\n                    [self.all_fp32_from_fp16_params, self.all_fp16_params],\n                    1.0)\n        else:\n            for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):\n                master_params_to_model_params(fp16_group, fp32_from_fp16_group)\n\n    # To consider:  Integrate distributed with this wrapper by registering a hook on each variable\n    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.\n    # def _model_grads_to_master_grads(self):\n    #     for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):\n    #         model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)\n\n    # def _downscale_master(self):\n    #     if self.loss_scale != 1.0:\n    #         for group in self.optimizer.param_groups:\n    #             for param in group['params']:\n    #                 if param.grad is not None:\n    #                     param.grad.data.mul_(1./self.loss_scale)\n\n    def clip_master_grads(self, max_norm, norm_type=2):\n        \"\"\"\n        Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.\n\n        Args:\n            max_norm (float or int): max norm of the gradients\n            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n                infinity norm.\n\n        Returns:\n            Total norm of the current fp32 gradients (viewed as a single vector).\n\n        .. 
warning::\n            Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).\n        \"\"\"\n        if not self.overflow:\n            fp32_params = []\n            for param_group in self.optimizer.param_groups:\n                for param in param_group['params']:\n                    fp32_params.append(param)\n            return self.clip_grad_norm(fp32_params, max_norm, norm_type)\n        else:\n            return -1\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict['loss_scaler'] = self.loss_scaler\n        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale\n        state_dict['overflow'] = self.overflow\n        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step\n        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()\n        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict(). 
\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, \n        whose parameters in turn came from ``model``, it is expected that the user \n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n\n        Example::\n\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # I think it should actually be ok to reload the optimizer before the model.\n        self.loss_scaler = state_dict['loss_scaler']\n        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']\n        self.overflow = state_dict['overflow']\n        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']\n        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.  \n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  There are two options.  \n        # 1:  Refresh the master params from the model's fp16 params.  \n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 2.\n        # \n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device \n        # of their associated parameters, because it's possible those buffers might not exist yet in \n        # the current optimizer instance.  
In our case, as long as the current FP16_Optimizer has been \n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):\n            for current, saved in zip(current_group, saved_group):\n                current.data.copy_(saved.data)\n\n    def step(self, closure=None): # could add clip option.\n        \"\"\"\n        If no closure is supplied, :attr:`step` should be called after \n        ``fp16_optimizer_obj.backward(loss)``.\n        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to\n        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params\n        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run\n        another forward pass using their model.\n\n        If a closure is supplied, :attr:`step` may be called without a prior call to \n        :attr:`backward(loss)`.\n        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.\n        However, the user should take care that any ``loss.backward()`` call within the closure\n        has been replaced by ``fp16_optimizer_obj.backward(loss)``.\n\n        Args:\n           closure (optional):  Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor.  
closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.\n\n        Example with closure::\n\n            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an \n            # existing pytorch optimizer.\n            for input, target in dataset:\n                def closure():\n                    optimizer.zero_grad()\n                    output = model(input)\n                    loss = loss_fn(output, target)\n                    # loss.backward() becomes:\n                    optimizer.backward(loss)\n                    return loss\n                optimizer.step(closure)\n\n        .. warning::\n            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.\n\n        .. _`ordinary Pytorch optimizer use`:\n            http://pytorch.org/docs/master/optim.html#optimizer-step-closure\n        \"\"\"\n\n        scale = self.loss_scaler.loss_scale()\n        # To consider:  Should this be in step(), or update_master_grads?  It works either way,\n        # but I should make it consistent with the Amp control flow, which updates the scale\n        # during backward context manager exit.\n        # self._update_scale(self.overflow)\n\n        if self.overflow:\n            # Using _amp_state.maybe_print instead of self.print here is intentional.\n            maybe_print(\"Gradient overflow.  
Skipping step, reducing \" +\n                \"loss scale to {}\".format(self.loss_scaler.loss_scale()))\n            return\n        \n        if closure is not None:\n            retval = self._step_with_closure(closure)\n        else:\n            # torch.cuda.nvtx.range_push(\"pytorch optimizer step\")\n            retval = self.optimizer.step()\n            # torch.cuda.nvtx.range_pop()\n\n        self._master_params_to_model_params()\n\n        return retval\n\n    def _step_with_closure(self, closure):\n        def wrapped_closure():\n            # helpful for debugging\n            # print(\"Calling wrapped_closure, first_closure_call_this_step = {}\"\n            #       .format(self.first_closure_call_this_step))\n            if self.first_closure_call_this_step:\n                # We expect that the fp16 params are initially fresh on entering self.step(),\n                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()\n                # is called within self.optimizer.step().\n                self.first_closure_call_this_step = False\n            else:\n                # If self.optimizer.step() internally calls wrapped_closure more than once,\n                # it may update the fp32 params after each call.  However, self.optimizer \n                # doesn't know about the fp16 params at all.  If the fp32 params get updated,\n                # we can't rely on self.optimizer to refresh the fp16 params.  
We need\n                # to handle that manually:\n                self._master_params_to_model_params()\n            # Our API expects the user to give us ownership of the backward() call by\n            # replacing all calls to loss.backward() with optimizer.backward(loss).\n            # This requirement holds whether or not the call to backward() is made within a closure.\n            # If the user is properly calling optimizer.backward(loss) within \"closure,\" \n            # calling closure() here will give the fp32 master params fresh gradients\n            # for the optimizer to play with, so all wrapped_closure needs to do is call \n            # closure() and return the loss.\n            temp_loss = closure() \n            while(self.overflow):\n                scale = self.loss_scaler.loss_scale()\n                # self._update_scale(self.overflow) # now done at the end of backward\n                print(\"OVERFLOW within closure! Skipping step, reducing loss scale to {}\".format(\n                      self.loss_scaler.loss_scale()))\n                temp_loss = closure()\n            return temp_loss\n\n        retval = self.optimizer.step(wrapped_closure)\n\n        self.first_closure_call_this_step = True\n\n        return retval\n\n    def backward(self, loss, update_master_grads=True, retain_graph=False):\n        \"\"\" \n        :attr:`backward` performs the following conceptual steps:\n\n        1. fp32_loss = loss.float() (see first Note below)\n        2. scaled_loss = fp32_loss*loss_scale\n        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).\n        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.\n        5. 
Finally, master grads are divided by loss_scale.\n\n        In this way, after :attr:`backward`, the master params have fresh gradients,\n        and :attr:`step` may be called.\n\n        .. note::\n            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.\n            This provides some additional safety against overflow if the user has supplied an \n            fp16 loss value.  \n            However, for maximum overflow safety, the user should\n            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to \n            :attr:`backward`.\n\n        .. warning::\n            The gradients found in a model's leaves after the call to \n            :attr:`backward` should not be regarded as valid in general, \n            because it's possible \n            they have been scaled (and in the case of dynamic loss scaling, \n            the scale factor may change over time).  \n            If the user wants to inspect gradients after a call to :attr:`backward`,  \n            only the master gradients should be regarded as valid.  These can be retrieved via\n            :attr:`inspect_master_grad_data()`.\n\n        Args:\n            loss:  The loss output by the user's model.  loss may be either float or half (but see first Note above).\n            update_master_grads (bool, optional, default=True):  Option to copy fp16 grads to fp32 grads on this call.  By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration.  If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.\n            retain_graph (bool, optional, default=False):  Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``.  
If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).\n\n        Example::\n\n            # Ordinary operation:\n            optimizer.backward(loss)\n\n            # Naive operation with multiple losses (technically valid, but less efficient):\n            # fp32 grads will be correct after the second call,  but \n            # the first call incurs an unnecessary fp16->fp32 grad copy.\n            optimizer.backward(loss1)\n            optimizer.backward(loss2)\n\n            # More efficient way to handle multiple losses:\n            # The fp16->fp32 grad copy is delayed until fp16 grads from all \n            # losses have been accumulated.\n            optimizer.backward(loss1, update_master_grads=False)\n            optimizer.backward(loss2, update_master_grads=False)\n            optimizer.update_master_grads()\n        \"\"\" \n        # To consider:  try multiple backward passes using retain_graph=True to find \n        # a loss scale that works.  After you find a loss scale that works, do a final dummy\n        # backward pass with retain_graph=False to tear down the graph.  Doing this would avoid \n        # discarding the iteration,  but probably wouldn't improve overall efficiency.  \n        scaled_loss = loss.float()*self.loss_scaler.loss_scale()\n        scaled_loss.backward(retain_graph=retain_graph)\n        if update_master_grads:\n            self.update_master_grads()\n\n    def update_master_grads(self):\n        # torch.cuda.nvtx.range_push(\"update_master_grads\")\n        \"\"\"\n        Copy the ``.grad`` attribute from stored references to fp16 parameters to \n        the ``.grad`` attribute of the fp32 master parameters that are directly \n        updated by the optimizer.  
:attr:`update_master_grads` only needs to be called if\n        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.\n        \"\"\"\n        # if self.dynamic_loss_scale:\n        #     self._check_overflow()\n        #     if self.overflow: return\n        # self._model_grads_to_master_grads()\n        # self._downscale_master()\n        # Use the one-shot multi-tensor apply kernel\n        self.loss_scaler.clear_overflow_state()\n        if len(self.all_fp16_params) > 0:\n            # print(\"Model grads before\")\n            # print([param.grad.data for param in self.all_fp16_params])\n            # I'm ONLY writing this as an incremental way to make some tests pass until\n            # I can refactor the tests as well.\n            # FP16_Optimizer should not be used by anyone.\n            model_grads = []\n            master_grads = []\n            for model_param, master_param in zip(self.all_fp16_params,\n                                                 self.all_fp32_from_fp16_params):\n                if model_param.grad is not None:\n                    model_grads.append(model_param.grad)\n                    if master_param.grad is None:\n                        master_param.grad = torch.empty_like(master_param)\n                    master_grads.append(master_param.grad)\n            self.loss_scaler.unscale(\n                model_grads,\n                master_grads,\n                self.loss_scaler.loss_scale())\n            # print(\"Master grads after\")\n            # print([param.grad.data for param in self.all_fp32_from_fp16_params])\n        if len(self.all_fp32_from_fp32_params) > 0:\n            model_grads = []\n            master_grads = []\n            for model_param, master_param in zip(self.all_fp32_from_fp32_params,\n                                                 self.all_fp32_from_fp32_params):\n                if model_param.grad is not None:\n                    model_grads.append(model_param.grad)\n  
                  master_grads.append(master_param.grad)\n            # print(\"Model grads before\")\n            # print([param.grad.data for param in self.all_fp32_from_fp32_params])\n            self.loss_scaler.unscale(\n                model_grads,\n                master_grads,\n                self.loss_scaler.loss_scale())\n            # print(\"Master grads after\")\n            # print([param.grad.data for param in self.all_fp32_from_fp32_params])\n        # quit()\n        self.overflow = self.loss_scaler.update_scale()\n        # torch.cuda.nvtx.range_pop()\n\n\n    def inspect_master_grad_data(self):\n        \"\"\"\n        When running with :class:`FP16_Optimizer`, \n        ``.grad`` attributes of a model's fp16 leaves should not be\n        regarded as truthful, because they might be scaled.  \n        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,\n        the fp32 master params' ``.grad``\n        attributes will contain valid gradients properly divided by the loss scale.  However, \n        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be \n        nonintuitive.  :attr:`inspect_master_grad_data`\n        allows those gradients to be viewed with shapes corresponding to their associated model leaves.\n\n        Returns:\n            List of lists (one list for each parameter group).  The list for each parameter group\n            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.                 \n        \"\"\"\n        if self.overflow:\n            print(\"Warning:  calling FP16_Optimizer.inspect_master_grad_data while in an overflow state.  \"\n                  \"Gradients are currently invalid (may be inf, nan, or stale).  
Returning None.\")\n            return None\n        else:\n            # The optimizer owns only references to master params.\n            master_grads_data = []\n            for param_group in self.optimizer.param_groups:\n                master_grads_this_group = []\n                for param in param_group['params']:\n                    if param.grad is not None:\n                        master_grads_this_group.append(param.grad.data)\n                    else:\n                        master_grads_this_group.append(None)\n                master_grads_data.append(master_grads_this_group)\n            return master_grads_data\n\n\n    # Promote loss scale so it can be retrieved or set via \"fp16_optimizer_instance.loss_scale\"\n    def _get_loss_scale(self):\n        return self.loss_scaler.loss_scale()\n\n    def _set_loss_scale(self, value):\n        self.loss_scaler._loss_scale = value\n\n    loss_scale = property(_get_loss_scale, _set_loss_scale)\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return self.optimizer.param_groups\n\n    def _set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n"
  },
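The scale-then-unscale contract that `FP16_Optimizer.backward` and `update_master_grads` implement in the file above can be illustrated without torch: every gradient produced under a scaled loss comes out multiplied by `loss_scale`, so dividing by the scale recovers the true gradients. A minimal stdlib sketch (the `unscale` helper and the numbers are illustrative, not part of apex):

```python
# Sketch of the loss-scaling round trip used by FP16_Optimizer.backward:
# backward() multiplies the loss by loss_scale, so every gradient comes out
# multiplied by loss_scale; update_master_grads() divides it back out.

def unscale(model_grads, loss_scale):
    """Divide scaled gradients by the loss scale (illustrative helper)."""
    return [g / loss_scale for g in model_grads]

loss_scale = 1024.0
true_grads = [0.5, -1.25, 3.0]

# Gradients as they appear on the fp16 model leaves after a scaled backward:
scaled_grads = [g * loss_scale for g in true_grads]

# Master (fp32) gradients after unscaling match the true gradients exactly,
# because the scale is a power of two:
master_grads = unscale(scaled_grads, loss_scale)
print(master_grads)  # -> [0.5, -1.25, 3.0]
```

Powers of two are used as loss scales precisely because multiplying and dividing by them is exact in floating point.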
  {
    "path": "KoSentenceT5/apex/fp16_utils/fp16util.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\n\nclass tofp16(nn.Module):\n    \"\"\"\n    Utility module that implements::\n\n        def forward(self, input):\n            return input.half()\n    \"\"\"\n\n    def __init__(self):\n        super(tofp16, self).__init__()\n\n    def forward(self, input):\n        return input.half()\n\n\ndef BN_convert_float(module):\n    \"\"\"\n    Utility function for network_to_half().\n\n    Retained for legacy purposes.\n    \"\"\"\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:\n        module.float()\n    for child in module.children():\n        BN_convert_float(child)\n    return module\n\n\ndef network_to_half(network):\n    \"\"\"\n    Convert model to half precision in a batchnorm-safe way.\n\n    Retained for legacy purposes. It is recommended to use FP16Model.\n    \"\"\"\n    return nn.Sequential(tofp16(), BN_convert_float(network.half()))\n\n\ndef convert_module(module, dtype):\n    \"\"\"\n    Converts a module's immediate parameters and buffers to dtype.\n    \"\"\"\n    for param in module.parameters(recurse=False):\n        if param is not None:\n            if param.data.dtype.is_floating_point:\n                param.data = param.data.to(dtype=dtype)\n            if param._grad is not None and param._grad.data.dtype.is_floating_point:\n                param._grad.data = param._grad.data.to(dtype=dtype)\n\n    for buf in module.buffers(recurse=False):\n        if buf is not None and buf.data.dtype.is_floating_point:\n            buf.data = buf.data.to(dtype=dtype)\n\n\ndef convert_network(network, dtype):\n    \"\"\"\n    Converts a network's parameters and buffers to dtype.\n    \"\"\"\n    for module in network.modules():\n        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:\n            continue\n        
convert_module(module, dtype)\n        if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):\n            module.flatten_parameters()\n    return network\n\n\nclass FP16Model(nn.Module):\n    \"\"\"\n    Convert model to half precision in a batchnorm-safe way.\n    \"\"\"\n\n    def __init__(self, network):\n        super(FP16Model, self).__init__()\n        self.network = convert_network(network, dtype=torch.half)\n\n    def forward(self, *inputs):\n        inputs = tuple(t.half() for t in inputs)\n        return self.network(*inputs)\n\n\ndef backwards_debug_hook(grad):\n    raise RuntimeError(\"master_params received a gradient in the backward pass!\")\n\ndef prep_param_lists(model, flat_master=False):\n    \"\"\"\n    Creates a list of FP32 master parameters for a given model, as in\n    `Training Neural Networks with Mixed Precision:  Real Examples`_.\n\n    Args:\n        model (torch.nn.Module): Existing Pytorch model\n        flat_master (bool, optional, default=False):  Flatten the master parameters into a single tensor, as a performance optimization.\n    Returns:\n        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`.  ``master_params`` is a list of FP32 master parameters.  If ``flat_master=True``, ``master_params`` will be a list with one element.\n\n    Example::\n\n        model_params, master_params = prep_param_lists(model)\n\n    .. warning::\n        Currently, if ``flat_master=True``, all the model's parameters must be the same type.  If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.\n\n    .. 
_`Training Neural Networks with Mixed Precision:  Real Examples`:\n        http://on-demand.gputechconf.com/gtc/2018/video/S81012/\n    \"\"\"\n    model_params = [param for param in model.parameters() if param.requires_grad]\n\n    if flat_master:\n        # Give the user some more useful error messages\n        try:\n            # flatten_dense_tensors returns a contiguous flat array.\n            # http://pytorch.org/docs/master/_modules/torch/_utils.html\n            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()\n        except:\n            print(\"Error in prep_param_lists:  model may contain a mixture of parameters \"\n                      \"of different types.  Use flat_master=False, or use FP16_Optimizer.\")\n            raise\n        master_params = torch.nn.Parameter(master_params)\n        master_params.requires_grad = True\n        # master_params.register_hook(backwards_debug_hook)\n        if master_params.grad is None:\n            master_params.grad = master_params.new(*master_params.size())\n        return model_params, [master_params]\n    else:\n        master_params = [param.clone().float().detach() for param in model_params]\n        for param in master_params:\n            param.requires_grad = True\n        return model_params, master_params\n\n\ndef model_grads_to_master_grads(model_params, master_params, flat_master=False):\n    \"\"\"\n    Copy model gradients to master gradients.  \n\n    Args:\n        model_params:  List of model parameters created by :func:`prep_param_lists`.\n        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  
If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.\n    \"\"\"\n    if flat_master:\n        # The flattening may incur one more deep copy than is necessary.\n        master_params[0].grad.data.copy_(\n            _flatten_dense_tensors([p.grad.data for p in model_params]))\n    else:\n        for model, master in zip(model_params, master_params):\n            if model.grad is not None:\n                if master.grad is None:\n                    master.grad = Variable(master.data.new(*master.data.size()))\n                master.grad.data.copy_(model.grad.data)\n            else:\n                master.grad = None\n\n\ndef master_params_to_model_params(model_params, master_params, flat_master=False):\n    \"\"\"\n    Copy master parameters to model parameters.\n\n    Args:\n        model_params:  List of model parameters created by :func:`prep_param_lists`.\n        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.\n    \"\"\"\n    if flat_master:\n        for model, master in zip(model_params, \n                                 _unflatten_dense_tensors(master_params[0].data, model_params)):\n            model.data.copy_(master)\n    else:\n        for model, master in zip(model_params, master_params):\n            model.data.copy_(master.data)\n\n# Backward compatibility fixes\n\ndef to_python_float(t):\n    if hasattr(t, 'item'):\n        return t.item()\n    else:\n        return t[0]\n\nTORCH_MAJOR = int(torch.__version__.split('.')[0])\nTORCH_MINOR = int(torch.__version__.split('.')[1])\nif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:\n    clip_grad_norm = torch.nn.utils.clip_grad_norm\nelse:\n    clip_grad_norm = torch.nn.utils.clip_grad_norm_\n"
  },
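The three helpers in fp16util.py above (`prep_param_lists`, `model_grads_to_master_grads`, `master_params_to_model_params`) together form one master-weights update cycle. A stdlib-only sketch of that cycle, with plain floats standing in for tensors (all values and names here are illustrative, not part of apex):

```python
# One mixed-precision update cycle, mirrored with plain Python floats.
# "model" params stand in for fp16 leaves, "master" params for fp32 copies.

model_params = [1.0, 2.0]            # fp16 weights (stand-ins)
master_params = list(model_params)   # fp32 master copies, as prep_param_lists makes

# A backward pass leaves gradients on the model params:
model_grads = [0.25, -0.5]

# model_grads_to_master_grads: copy grads onto the master copies.
master_grads = list(model_grads)

# The optimizer step runs entirely on the fp32 master params:
lr = 0.5
master_params = [p - lr * g for p, g in zip(master_params, master_grads)]

# master_params_to_model_params: copy updated masters back to the model.
model_params = list(master_params)
print(model_params)  # -> [0.875, 2.25]
```

The point of the round trip is that the small update `lr * g` is accumulated in fp32, where it cannot be swallowed by fp16 rounding.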
  {
    "path": "KoSentenceT5/apex/fp16_utils/loss_scaler.py",
    "content": "import torch\n\n# item() is a recent addition, so this helps with backward compatibility.\ndef to_python_float(t):\n    if hasattr(t, 'item'):\n        return t.item()\n    else:\n        return t[0]\n\nclass LossScaler:\n    \"\"\"\n    Class that manages a static loss scale.  This class is intended to interact with\n    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.\n\n    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to \n    :class:`FP16_Optimizer`'s constructor.\n\n    Args:\n        scale (float, optional, default=1.0):  The loss scale.\n    \"\"\"\n\n    def __init__(self, scale=1):\n        self.cur_scale = scale\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow(self, params):\n        return False\n\n    # `x` is a torch.Tensor\n    def _has_inf_or_nan(x):\n        return False\n\n    def update_scale(self, overflow):\n        pass\n\n    @property\n    def loss_scale(self):\n        return self.cur_scale\n\n    def scale_gradient(self, module, grad_in, grad_out):\n        return tuple(self.loss_scale * g for g in grad_in)\n\n    def backward(self, loss, retain_graph=False):\n        scaled_loss = loss*self.loss_scale\n        scaled_loss.backward(retain_graph=retain_graph)\n\nclass DynamicLossScaler:\n    \"\"\"\n    Class that manages dynamic loss scaling.  It is recommended to use :class:`DynamicLossScaler`\n    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of \n    :class:`FP16_Optimizer`.  However, it's important to understand how :class:`DynamicLossScaler`\n    operates, because the default options can be changed using the\n    the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.\n\n    Loss scaling is designed to combat the problem of underflowing gradients encountered at long\n    times when training fp16 networks.  Dynamic loss scaling begins by attempting a very high loss\n    scale.  
Ironically, this may result in OVERflowing gradients.  If overflowing gradients are\n    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has \n    occurred.\n    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,\n    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.  \n    If a certain number of iterations occur without overflowing gradients being detected,\n    :class:`DynamicLossScaler` increases the loss scale once more.\n    In this way :class:`DynamicLossScaler` attempts to \"ride the edge\" of \n    always using the highest loss scale possible without incurring overflow.\n\n    Args:\n        init_scale (float, optional, default=2**32):  Initial loss scale attempted by :class:`DynamicLossScaler`.\n        scale_factor (float, optional, default=2.0):  Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to ``loss_scale``/``scale_factor``.  If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to ``loss_scale``*``scale_factor``. 
\n        scale_window (int, optional, default=1000):  Number of consecutive iterations without an overflow to wait before increasing the loss scale.\n    \"\"\"\n\n    def __init__(self,\n                 init_scale=2**32,\n                 scale_factor=2.,\n                 scale_window=1000):\n        self.cur_scale = init_scale\n        self.cur_iter = 0\n        self.last_overflow_iter = -1\n        self.scale_factor = scale_factor\n        self.scale_window = scale_window\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow(self, params):\n        for p in params:\n            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):\n                return True\n\n        return False\n\n    # `x` is a torch.Tensor\n    def _has_inf_or_nan(x):\n        try:\n            # if x is half, the .float() incurs an additional deep copy, but it's necessary if \n            # Pytorch's .sum() creates a one-element tensor of the same type as x \n            # (which is true for some recent version of pytorch).\n            cpu_sum = float(x.float().sum())\n            # More efficient version that can be used if .sum() returns a Python scalar\n            # cpu_sum = float(x.sum())\n        except RuntimeError as instance:\n            # We want to check if inst is actually an overflow exception.\n            # RuntimeError could come from a different error.\n            # If so, we still want the exception to propagate.\n            if \"value cannot be converted\" not in instance.args[0]:\n                raise\n            return True\n        else:\n            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n                return True\n            return False\n\n    # `overflow` is boolean indicating whether the gradient overflowed\n    def update_scale(self, overflow):\n        if overflow:\n            # self.cur_scale /= self.scale_factor\n            self.cur_scale = 
max(self.cur_scale/self.scale_factor, 1)\n            self.last_overflow_iter = self.cur_iter\n        else:\n            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:\n                self.cur_scale *= self.scale_factor\n        self.cur_iter += 1\n\n    @property\n    def loss_scale(self):\n        return self.cur_scale\n\n    def scale_gradient(self, module, grad_in, grad_out):\n        return tuple(self.loss_scale * g for g in grad_in)\n\n    def backward(self, loss, retain_graph=False):\n        scaled_loss = loss*self.loss_scale\n        scaled_loss.backward(retain_graph=retain_graph)\n        \n##############################################################        \n# Example usage below here -- assuming it's in a separate file\n##############################################################\n\"\"\"\nTO-DO separate out into an example.\nif __name__ == \"__main__\":\n    import torch\n    from torch.autograd import Variable\n    from dynamic_loss_scaler import DynamicLossScaler\n\n    # N is batch size; D_in is input dimension;\n    # H is hidden dimension; D_out is output dimension.\n    N, D_in, H, D_out = 64, 1000, 100, 10\n\n    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n    x = Variable(torch.randn(N, D_in), requires_grad=False)\n    y = Variable(torch.randn(N, D_out), requires_grad=False)\n\n    w1 = Variable(torch.randn(D_in, H), requires_grad=True)\n    w2 = Variable(torch.randn(H, D_out), requires_grad=True)\n    parameters = [w1, w2]\n\n    learning_rate = 1e-6\n    optimizer = torch.optim.SGD(parameters, lr=learning_rate)\n    loss_scaler = DynamicLossScaler()\n\n    for t in range(500):\n        y_pred = x.mm(w1).clamp(min=0).mm(w2)\n        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale\n        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))\n        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))\n        print('Iter {} unscaled loss: 
{}'.format(t, loss.data[0] / loss_scaler.loss_scale))\n\n        # Run backprop\n        optimizer.zero_grad()\n        loss.backward()\n        \n        # Check for overflow\n        has_overflow = DynamicLossScaler.has_overflow(parameters)\n        \n        # If no overflow, unscale grad and update as usual\n        if not has_overflow:\n            for param in parameters:\n                param.grad.data.mul_(1. / loss_scaler.loss_scale)\n            optimizer.step()\n        # Otherwise, don't do anything -- ie, skip iteration\n        else:\n            print('OVERFLOW!')\n\n        # Update loss scale for next iteration\n        loss_scaler.update_scale(has_overflow)\n\n\"\"\"\n"
  },
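The `update_scale` schedule in `DynamicLossScaler` above is a small state machine: halve the scale (floored at 1) on overflow, double it after `scale_window` clean iterations. A stdlib-only re-implementation of just that bookkeeping, with a small window so the behavior is easy to trace (the class name and window size are illustrative):

```python
class ScaleSchedule:
    """Mirror of DynamicLossScaler.update_scale's bookkeeping (illustrative)."""
    def __init__(self, init_scale=2.0**16, scale_factor=2.0, scale_window=4):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def update_scale(self, overflow):
        if overflow:
            # Back off, but never below 1.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            # scale_window clean iterations in a row: try a larger scale again.
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1

s = ScaleSchedule()
s.update_scale(True)        # overflow: 65536 -> 32768, update step is skipped
for _ in range(4):          # four clean iterations
    s.update_scale(False)
print(s.cur_scale)          # -> 65536.0  (back to the pre-overflow scale)
```

This is the "ride the edge" behavior the docstring describes: the scale oscillates just below the largest value that does not overflow.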
  {
    "path": "KoSentenceT5/apex/mlp/__init__.py",
    "content": "from .mlp import *\n"
  },
  {
    "path": "KoSentenceT5/apex/mlp/mlp.py",
    "content": "from copy import copy\nimport math\nimport torch\nfrom torch import nn\nimport mlp_cuda\nfrom .. import amp\n\nclass MlpFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, bias, activation, *args):\n        output = mlp_cuda.forward(bias, activation, args)\n        ctx.save_for_backward(*args)\n        ctx.outputs = output\n        ctx.bias = bias\n        ctx.activation = activation\n        return output[0]\n\n    @staticmethod\n    def backward(ctx, grad_o):\n        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)\n        del ctx.outputs\n        return (None, None, *grads)\n\nmlp_function = amp.half_function(MlpFunction.apply)\n\nclass MLP(torch.nn.Module):\n    \"\"\"Launch MLP in C++\n\n    Args:\n        mlp_sizes (list of int): MLP sizes. Example: [1024,1024,1024] will create 2 MLP layers with shape 1024x1024\n        bias (bool): Default True:\n        relu (bool): Default True\n    \"\"\"\n    def __init__(self, mlp_sizes, bias=True, activation='relu'):\n        super(MLP, self).__init__()\n        self.num_layers = len(mlp_sizes) - 1\n        self.mlp_sizes = copy(mlp_sizes)\n        self.bias = 1 if bias else 0\n\n        if activation is 'none':\n            self.activation = 0\n        elif activation is 'relu':\n            self.activation = 1\n        elif activation is 'sigmoid':\n            self.activation = 2\n        else:\n            raise TypeError(\"activation must be relu or none.\")\n\n        self.weights = []\n        self.biases = []\n        for i in range(self.num_layers):\n            w = torch.nn.Parameter(torch.empty(mlp_sizes[i+1], mlp_sizes[i]))\n            self.weights.append(w)\n            name = 'weight_{}'.format(i)\n            setattr(self, name, w)\n            if self.bias:\n                b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))\n                self.biases.append(b)\n                name = 'bias_{}'.format(i)\n            
    setattr(self, name, b)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for weight in self.weights:\n            dimsum = weight.size(0) + weight.size(1)\n            std = math.sqrt(2. / float(dimsum))\n            nn.init.normal_(weight, 0., std)\n        if self.bias:\n            for bias in self.biases:\n                std = math.sqrt(1. / float(bias.size(0)))\n                nn.init.normal_(bias, 0., std)\n\n    def forward(self, input):\n        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)\n\n    def extra_repr(self):\n        s = F\"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}\"\n        return s\n"
  },
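`MLP.reset_parameters` in mlp.py above draws weights from a normal distribution with Glorot-style std `sqrt(2 / (fan_in + fan_out))` and biases with std `sqrt(1 / fan_out)`. The arithmetic can be checked with the stdlib (the layer size below is just an example):

```python
import math

def weight_std(fan_in, fan_out):
    # std used by MLP.reset_parameters for a weight of shape (fan_out, fan_in):
    # dimsum = fan_in + fan_out, std = sqrt(2 / dimsum)
    return math.sqrt(2.0 / float(fan_in + fan_out))

def bias_std(fan_out):
    # std used for the matching bias vector: sqrt(1 / fan_out)
    return math.sqrt(1.0 / float(fan_out))

# For a 1024 -> 1024 layer, both stds reduce to sqrt(1/1024) = 1/32:
print(weight_std(1024, 1024))  # -> 0.03125
print(bias_std(1024))          # -> 0.03125
```

Keeping the std tied to the fan-in/fan-out like this keeps activation magnitudes roughly constant across layers at initialization.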
  {
    "path": "KoSentenceT5/apex/multi_tensor_apply/__init__.py",
    "content": "from .multi_tensor_apply import MultiTensorApply\n\nmulti_tensor_applier = MultiTensorApply(2048*32)\n\n"
  },
  {
    "path": "KoSentenceT5/apex/multi_tensor_apply/multi_tensor_apply.py",
    "content": "import torch\n\nclass MultiTensorApply(object):\n    available = False\n    warned = False\n\n    def __init__(self, chunk_size):\n        try:\n            import amp_C\n            MultiTensorApply.available = True\n            self.chunk_size = chunk_size\n        except ImportError as err:\n            MultiTensorApply.available = False\n            MultiTensorApply.import_err = err\n\n    def check_avail(self):\n        if MultiTensorApply.available == False:\n            raise RuntimeError(\n                \"Attempted to call MultiTensorApply method, but MultiTensorApply \"\n                \"is not available, possibly because Apex was installed without \"\n                \"--cpp_ext --cuda_ext.  Original import error message:\",\n                MultiTensorApply.import_err)\n\n    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):\n        self.check_avail()\n\n        return op(self.chunk_size,\n                  noop_flag_buffer,\n                  tensor_lists,\n                  *args)\n"
  },
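`MultiTensorApply` above just forwards a fixed `chunk_size` (2048*32 in `__init__.py`) to a fused CUDA op, which walks every tensor in `tensor_lists` in chunks of that size. The chunking idea can be sketched in pure Python, with lists standing in for tensors (`chunked_apply` and `scale` are illustrative stand-ins, not apex or amp_C APIs):

```python
def chunked_apply(op, chunk_size, tensors, *args):
    # Walk every tensor in fixed-size chunks and apply `op` to each chunk,
    # mimicking how the fused kernel traverses its tensor lists.
    for t in tensors:
        for start in range(0, len(t), chunk_size):
            op(t, start, min(start + chunk_size, len(t)), *args)

def scale(t, start, end, factor):
    # Illustrative elementwise op (stand-in for an amp_C kernel): in-place scaling.
    for i in range(start, end):
        t[i] *= factor

tensors = [[1.0, 2.0, 3.0], [4.0, 5.0]]
chunked_apply(scale, 2, tensors, 0.5)
print(tensors)  # -> [[0.5, 1.0, 1.5], [2.0, 2.5]]
```

On the GPU the payoff is that one kernel launch covers many small tensors, instead of one launch per tensor.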
  {
    "path": "KoSentenceT5/apex/normalization/__init__.py",
    "content": "from .fused_layer_norm import FusedLayerNorm\n"
  },
  {
    "path": "KoSentenceT5/apex/normalization/fused_layer_norm.py",
    "content": "import math\nimport torch\nimport numbers\nfrom torch.nn.parameter import Parameter\nfrom torch.nn import init\nfrom torch.nn import functional as F\nimport importlib\n\nglobal fused_layer_norm_cuda\nfused_layer_norm_cuda = None\n\nclass FusedLayerNormAffineFunction(torch.autograd.Function):\n\n  @staticmethod\n  def forward(ctx, input, weight, bias, normalized_shape, eps):\n    global fused_layer_norm_cuda\n    if fused_layer_norm_cuda is None:\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n    ctx.normalized_shape = normalized_shape\n    ctx.eps = eps\n    input_ = input.contiguous()\n    weight_ = weight.contiguous()\n    bias_ = bias.contiguous()\n    output, mean, invvar = fused_layer_norm_cuda.forward_affine(\n        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)\n    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)\n    return output\n\n  @staticmethod\n  def backward(ctx, grad_output):\n    input_, weight_, bias_, mean, invvar = ctx.saved_tensors\n    grad_input = grad_weight = grad_bias = None\n    grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(\n        grad_output.contiguous(), mean, invvar,\n        input_, ctx.normalized_shape,\n        weight_, bias_, ctx.eps)\n    return grad_input, grad_weight, grad_bias, None, None\n\nclass FusedLayerNormFunction(torch.autograd.Function):\n\n  @staticmethod\n  def forward(ctx, input, normalized_shape, eps):\n    global fused_layer_norm_cuda\n    if fused_layer_norm_cuda is None:\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n    ctx.normalized_shape = normalized_shape\n    ctx.eps = eps\n    input_ = input.contiguous()\n    output, mean, invvar = fused_layer_norm_cuda.forward(\n        input_, ctx.normalized_shape, ctx.eps)\n    ctx.save_for_backward(input_, mean, invvar)\n    return output\n\n  @staticmethod\n  def backward(ctx, grad_output):\n    input_, mean, invvar = 
ctx.saved_tensors\n    grad_input = None\n    grad_input = fused_layer_norm_cuda.backward(\n        grad_output.contiguous(), mean, invvar,\n        input_, ctx.normalized_shape,\n        ctx.eps)\n    return grad_input, None, None\n\ndef fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):\n    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)\n\ndef fused_layer_norm(input, normalized_shape, eps=1e-6):\n    return FusedLayerNormFunction.apply(input, normalized_shape, eps)\n\nclass FusedLayerNorm(torch.nn.Module):\n    r\"\"\"Applies Layer Normalization over a mini-batch of inputs as described in\n    the paper `Layer Normalization`_ .\n\n    Currently only runs on cuda() tensors.\n\n    .. math::\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated separately over the last\n    certain number dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n\n    .. note::\n        Unlike Batch Normalization and Instance Normalization, which applies\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, Layer Normalization applies per-element scale and\n        bias with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or torch.Size): input shape from an expected input\n            of size\n\n            .. 
math::\n                [* \\times \\text{normalized}\\_\\text{shape}[0] \\times \\text{normalized}\\_\\text{shape}[1]\n                    \\times \\ldots \\times \\text{normalized}\\_\\text{shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    Examples::\n\n        >>> input = torch.randn(20, 5, 10, 10)\n        >>> # With Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])\n        >>> # Without Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)\n        >>> # Normalize over last two dimensions\n        >>> m = apex.normalization.FusedLayerNorm([10, 10])\n        >>> # Normalize over last dimension of size 10\n        >>> m = apex.normalization.FusedLayerNorm(10)\n        >>> # Activating the module\n        >>> output = m(input)\n\n    .. 
_`Layer Normalization`: https://arxiv.org/abs/1607.06450\n    \"\"\"\n    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):\n        super(FusedLayerNorm, self).__init__()\n\n        global fused_layer_norm_cuda\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        if self.elementwise_affine:\n            self.weight = Parameter(torch.Tensor(*normalized_shape))\n            self.bias = Parameter(torch.Tensor(*normalized_shape))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.elementwise_affine:\n            init.ones_(self.weight)\n            init.zeros_(self.bias)\n\n    def forward(self, input):\n        if not input.is_cuda:\n            return  F.layer_norm(\n                input, self.normalized_shape, self.weight, self.bias, self.eps)\n        if self.elementwise_affine:\n          return FusedLayerNormAffineFunction.apply(\n              input, self.weight, self.bias, self.normalized_shape,self.eps)\n        else:\n          return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)\n\n    def extra_repr(self):\n        return '{normalized_shape}, eps={eps}, ' \\\n            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)\n"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/__init__.py",
    "content": "from .fused_sgd import FusedSGD\nfrom .fused_adam import FusedAdam\nfrom .fused_novograd import FusedNovoGrad\nfrom .fused_lamb import FusedLAMB\nfrom .fused_adagrad import FusedAdagrad"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/fused_adagrad.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedAdagrad(torch.optim.Optimizer):\n    \"\"\"Implements Adagrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adagrad implements 2 fusions.\n      * Fusion of the Adagrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdagrad` may be used with or without Amp.  If you wish to use :class:`FusedAdagrad` with Amp,\n    you may choose any ``opt_level``::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Adaptive Subgradient Methods for Online Learning\n    and Stochastic Optimization`_.\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-2)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-10)\n        adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: False)\n\n    .. 
_Adaptive Subgradient Methods for Online Learning and Stochastic\n        Optimization: http://jmlr.org/papers/v12/duchi11a.html\n    \"\"\"\n    def __init__(self, params, lr=1e-2, eps=1e-10,\n                 weight_decay=0., set_grad_none=True, adagrad_w_mode=False):\n\n        defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay)\n        super(FusedAdagrad, self).__init__(params, defaults)\n        self.adagrad_w_mode = 1 if adagrad_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad\n        else:\n            raise RuntimeError('apex.optimizers.FusedAdagrad requires cuda extensions')\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedAdagrad, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            # create lists for multi-tensor apply\n            g_16, p_16, h_16 = [], [], []\n            g_32, p_32, h_32 = [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedAdagrad does not support sparse gradients')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Sum of squared gradient values\n                    state['sum'] = torch.zeros_like(p.data)\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    h_16.append(state['sum'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    h_32.append(state['sum'])\n                else:\n                    raise RuntimeError('FusedAdagrad only supports fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_adagrad,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, h_16],\n                                     group['lr'],\n                                     group['eps'],\n                                     self.adagrad_w_mode,\n                                     group['weight_decay'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_adagrad,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, h_32],\n                                     group['lr'],\n                                     group['eps'],\n                                     self.adagrad_w_mode,\n                                     group['weight_decay'])\n\n        return loss"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/fused_adam.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adam implements 2 fusions.\n\n      * Fusion of the Adam update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,\n    or ``torch.optim.Adam`` with ``adam_w_mode=False``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdam` may be used with or without Amp.  If you wish to use :class:`FusedAdam` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n\n    .. warning::\n        A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``.  These additional arguments\n        are now deprecated and unnecessary.\n\n    Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. 
(default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: True)\n        set_grad_none (bool, optional): whether to set grads to None when the zero_grad()\n            method is called. (default: True)\n\n    .. _Adam - A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,\n                 weight_decay=0., amsgrad=False, set_grad_none=True):\n\n        if amsgrad:\n            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay)\n        super(FusedAdam, self).__init__(params, defaults)\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_adam = amp_C.multi_tensor_adam\n        else:\n            raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions')\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedAdam, self).zero_grad()\n\n    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n\n        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.\n        \"\"\"\n        if any(p is not None for p in [grads, output_params, scale, grad_norms]):\n            raise RuntimeError('FusedAdam has been updated.  Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n\n            # assume same step across group now to simplify things\n            # per parameter step can easily be supported by making it a tensor, or by passing a list into the kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedAdam does not support sparse gradients; please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = 
torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n                    v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedAdam only supports fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_adam,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     self.adam_w_mode,\n                                     bias_correction,\n                                     group['weight_decay'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_adam,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     self.adam_w_mode,\n                                     bias_correction,\n                        
             group['weight_decay'])\n\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/fused_lamb.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay(also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether set grad to None when zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0\n            weight decay parameter (default: False)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,\n                 amsgrad=False, adam_w_mode=True,\n                 grad_averaging=True, set_grad_none=True,\n                 max_grad_norm=1.0, use_nvlamb=False):\n        if amsgrad:\n            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_lamb = amp_C.multi_tensor_lamb\n        else:\n            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n        self.use_nvlamb = use_nvlamb\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create 
separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')\n\n        device = self.param_groups[0][\"params\"][0].device\n        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_32], False)[0]\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_16], False)[0]\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                self._dummy_overflow_buf,\n                                                [[g_norm_32, g_norm_16]],\n                                                False)[0]\n        max_grad_norm = self.defaults['max_grad_norm']\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can easily be supported by making it a tensor, or by passing a list into the 
kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedLAMB does not support sparse gradients; please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n                    v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     
group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm,\n                                     self.use_nvlamb)\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm,\n                                     self.use_nvlamb)\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/fused_novograd.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedNovoGrad(torch.optim.Optimizer):\n\n    \"\"\"Implements NovoGrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused NovoGrad implements 2 fusions.\n\n      * Fusion of the NovoGrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedNovoGrad`'s usage is identical to any Pytorch optimizer::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedNovoGrad` may be used with or without Amp.  If you wish to use :class:`FusedNovoGrad` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.\n    More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        reg_inside_moment (bool, optional): whether to do regularization (norm and L2)\n            in the momentum calculation. True to include it, False to exclude it and\n            apply it only to the update term. (default: False)\n        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when\n            calculating running averages of gradient. (default: True)\n        norm_type (int, optional): which norm to calculate for each layer.\n            2 for L2 norm, and 0 for infinity norm. These are the only supported\n            types now. (default: 2)\n        init_zero (bool, optional): whether to initialize the norm with 0 (start averaging on\n            the 1st step) or with the first-step norm (start averaging on the 2nd step). True to\n            initialize with 0. (default: False)\n        set_grad_none (bool, optional): whether to set grads to None when the zero_grad()\n            method is called. (default: True)\n\n    .. _Jasper - An End-to-End Convolutional Neural Acoustic Model:\n        https://arxiv.org/abs/1904.03288\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.,\n                 amsgrad=False, reg_inside_moment=False,\n                 grad_averaging=True, norm_type=2, init_zero=False,\n                 set_grad_none=True):\n        if amsgrad:\n            raise RuntimeError('FusedNovoGrad does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging, norm_type=norm_type,\n                        init_zero=init_zero)\n        super(FusedNovoGrad, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n\n            # Creating the overflow buffer on the same device as the params tensors.\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_novograd = amp_C.multi_tensor_novograd\n        else:\n            raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')\n\n        self.moment_mode = 0 if reg_inside_moment else 1\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedNovoGrad, self).zero_grad()\n\n    def load_state_dict(self, state_dict):\n        super(FusedNovoGrad, self).load_state_dict(state_dict)\n        # in case exp_avg_sq is not on the same device as params, move it there\n        for group in self.param_groups:\n            if len(group['params']) > 0:\n                group['exp_avg_sq'][0] = 
group['exp_avg_sq'][0].to(group['params'][0].device)\n                group['exp_avg_sq'][1] = group['exp_avg_sq'][1].to(group['params'][0].device)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can easily be supported by making it a tensor, or by passing a list into the kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16 = [], [], []\n            g_32, p_32, m_32 = [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedNovoGrad does not support sparse gradients; please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                 
   m_32.append(state['exp_avg'])\n                else:\n                    raise RuntimeError('FusedNovoGrad only supports fp16 and fp32.')\n\n            # we store per weight norm as one tensor for one group/precision combination\n            # different from optim.Adam, we store norm here (not ^2) so we can unify calculation for norm types\n            if 'exp_avg_sq' not in group:\n                group['exp_avg_sq'] = [None, None]\n                if group['init_zero']:\n                    # Creating the following parameters on the same device as the params tensors.\n                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0][\"params\"][0].device).contiguous().fill_(0)\n                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0][\"params\"][0].device).contiguous().fill_(0)\n                else: # init with first step norm, so the first blend has no effect\n                    if group['norm_type'] == 0:\n                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]\n                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]\n                    elif group['norm_type'] == 2:\n                        v_16 = [torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16]\n                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]\n                    else:\n                        raise RuntimeError('FusedNovoGrad only supports l2/inf norm now.')\n                    # Creating the following parameters on the same device as the params tensors.\n                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0][\"params\"][0].device)\n                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0][\"params\"][0].device)\n            else:\n                assert(len(g_16) == group['exp_avg_sq'][0].numel())\n   
             assert(len(g_32) == group['exp_avg_sq'][1].numel())\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_novograd,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16],\n                                     group['exp_avg_sq'][0],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.moment_mode,\n                                     group['norm_type'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_novograd,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32],\n                                     group['exp_avg_sq'][1],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.moment_mode,\n                                     group['norm_type'])\n\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/optimizers/fused_sgd.py",
    "content": "import torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused SGD implements 2 fusions.\n\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedSGD` may be used with or without Amp.  If you wish to use :class:`FusedSGD` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n        
>>> optimizer.zero_grad()\n        >>> loss_fn(model(input), target).backward()\n        >>> optimizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et. al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et. al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(self, params, lr=required, momentum=0, dampening=0,\n                 weight_decay=0, nesterov=False,\n                 wd_after_momentum=False,\n                 materialize_master_grads=True,\n                 set_grad_none=False):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,\n                        weight_decay=weight_decay, nesterov=nesterov)\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = wd_after_momentum\n        self.materialize_master_grads = materialize_master_grads\n        
self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('nesterov', False)\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedSGD, self).zero_grad()\n\n    def get_momentums(self, params):\n        momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if 'momentum_buffer' not in param_state:\n                first_run = True\n                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state['momentum_buffer'])\n        return momentums, first_run\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = 
closure()\n\n        explicit_master_params = (hasattr(self, \"_amp_stash\") and\n                                  hasattr(self._amp_stash, \"fp32_from_fp16_groups\"))\n\n        for gid, group in enumerate(self.param_groups):\n            weight_decay = group['weight_decay']\n            momentum = group['momentum']\n            dampening = group['dampening']\n            nesterov = group['nesterov']\n\n\n            # For each group, there are 3 possible combinations we need to consider:\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # 1. fp16, fp16, fp16, No\n            # 2. fp32, fp32, fp32, No\n            # 3. fp16, fp32, fp32, Yes\n\n            first_runs = [True, True]\n\n            # I think a bit of code divergence in exchange for naming clarity is worthwhile\n            if explicit_master_params:\n                stash = self._amp_stash\n\n                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]\n                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                if self.materialize_master_grads:\n                    fp16_model_params = [p for i, p in enumerate(\n                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]\n                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n\n                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,\n                                fp32_from_fp16_momentums, fp16_model_params]\n                else:\n                    fp16_model_params = [p for p in 
stash.fp16_groups[gid] if p.grad is not None]\n                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_params = [p for i, p in enumerate(\n                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]\n                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n\n                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,\n                                fp32_from_fp16_momentums, fp16_model_params]\n\n                launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]\n            else:\n                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]\n                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]\n                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)\n\n                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]\n                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],\n                               [fp32_grads, fp32_params, fp32_momentums]]\n\n            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                
        dampening,\n                        group['lr'],\n                        nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0/self.most_recent_scale)\n\n        self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n\n        return loss\n"
  },
  {
    "path": "KoSentenceT5/apex/parallel/LARC.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn.parameter import Parameter\n\nclass LARC(object):\n    \"\"\"\n    :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,\n    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive \n    local learning rate for each individual parameter. The algorithm is designed to improve\n    convergence of large batch training.\n     \n    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.\n\n    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate\n    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.\n\n    ```\n    model = ...\n    optim = torch.optim.Adam(model.parameters(), lr=...)\n    optim = LARC(optim)\n    ```\n\n    It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.\n\n    ```\n    model = ...\n    optim = torch.optim.Adam(model.parameters(), lr=...)\n    optim = LARC(optim)\n    optim = apex.fp16_utils.FP16_Optimizer(optim)\n    ```\n\n    Args:\n        optimizer: Pytorch optimizer to wrap and modify learning rate for.\n        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888\n        clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. 
If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.\n        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr\n    \"\"\"\n\n    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):\n        self.optim = optimizer\n        self.trust_coefficient = trust_coefficient\n        self.eps = eps\n        self.clip = clip\n\n    def __getstate__(self):\n        return self.optim.__getstate__()\n\n    def __setstate__(self, state):\n        self.optim.__setstate__(state)\n\n    @property\n    def state(self):\n        return self.optim.state\n\n    def __repr__(self):\n        return self.optim.__repr__()\n\n    @property\n    def param_groups(self):\n        return self.optim.param_groups\n\n    @param_groups.setter\n    def param_groups(self, value):\n        self.optim.param_groups = value\n    \n    def state_dict(self):\n        return self.optim.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self.optim.load_state_dict(state_dict)\n\n    def zero_grad(self):\n        self.optim.zero_grad()\n\n    def add_param_group(self, param_group):\n        self.optim.add_param_group( param_group)\n\n    def step(self):\n        with torch.no_grad():\n            weight_decays = []\n            for group in self.optim.param_groups:\n                # absorb weight decay control from optimizer\n                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0\n                weight_decays.append(weight_decay)\n                group['weight_decay'] = 0\n                for p in group['params']:\n                    if p.grad is None:\n                        continue\n                    param_norm = torch.norm(p.data)\n                    grad_norm = torch.norm(p.grad.data)\n\n                    if param_norm != 0 and grad_norm != 0:\n                        # calculate adaptive lr + weight decay\n                        adaptive_lr = self.trust_coefficient * 
(param_norm) / (grad_norm + param_norm * weight_decay + self.eps)\n\n                        # clip learning rate for LARC\n                        if self.clip:\n                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`\n                            adaptive_lr = min(adaptive_lr/group['lr'], 1)\n\n                        p.grad.data += weight_decay * p.data\n                        p.grad.data *= adaptive_lr\n\n        self.optim.step()\n        # return weight decay control to optimizer\n        for i, group in enumerate(self.optim.param_groups):\n            group['weight_decay'] = weight_decays[i]\n"
  },
  {
    "path": "KoSentenceT5/apex/parallel/README.md",
    "content": "## Distributed Data Parallel\n\ndistributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library.\n\n`apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with\ncomputation in the backward pass and bucketing smaller transfers to reduce the total number of\ntransfers required.\n\nmultiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs.\n\n#### [API Documentation](https://nvidia.github.io/apex/parallel.html)\n\n#### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)\n\n#### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\n#### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex)\n\n### Synchronized Batch Normalization\n\n`apex.parallel.SyncBatchNorm` has similar APIs as with `torch.nn.BatchNorm*N*d`.\nIt reduces stats on the first (channel) dimension of the Tensor and accepts\narbitrary spatial dimensions.\n\n#### Installation\n\nApex provides two sync BN implementation:\n\n1. There is the Python-only implementation, which is the default implementation\nwhen install with `python setup.py install`.\nIt uses PyTorch primitive operations and distributed communication package from\n`torch.distributed`.\n\n   - _Python-only implementation requires input tensor to be of same data type as\nlayer_\n\n2. We also provide implementation with kernels through CUDA/C++ extension with\nimproved performance. 
We are experimenting with Welford and Kahan for reduction\nhoping to get better accuracy.\n   To use the kernel implementation, user need to install Apex with CUDA extension\nenabled `python setup.py install --cuda_ext`.\n\n   - _Custom kernel implementation supports fp16 input with fp32 layer as cudnn.\nThis is required to run imagenet example in fp16._\n\n   - _Currently kernel implementation only supports GPU._\n\n#### HowTo\n\n1. User could use `apex.parallel.SyncBatchNorm` by building their module with\nthe layer explicitly.\n\n```\nimport apex\ninput_t = torch.randn(3, 5, 20).cuda()\nsbn = apex.parallel.SyncBatchNorm(5).cuda()\noutput_t = sbn(input)\n```\n\n2. User could also take a constructed `torch.nn.Model` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through utility function `apex.parallel.convert_syncbn_model`.\n\n```\n# model is an instance of torch.nn.Module\nimport apex\nsync_bn_model = apex.parallel.convert_syncbn_model(model)\n```\n"
  },
  {
    "path": "KoSentenceT5/apex/parallel/__init__.py",
    "content": "import torch\n\nif hasattr(torch.distributed, 'ReduceOp'):\n    ReduceOp = torch.distributed.ReduceOp\nelif hasattr(torch.distributed, 'reduce_op'):\n    ReduceOp = torch.distributed.reduce_op\nelse:\n    ReduceOp = torch.distributed.deprecated.reduce_op\n\nfrom .distributed import DistributedDataParallel, Reducer\n# This is tricky because I'd like SyncBatchNorm to be exposed the same way\n# for both the cuda-enabled and python-fallback versions, and I don't want\n# to suppress the error information.\ntry:\n    import syncbn\n    from .optimized_sync_batchnorm import SyncBatchNorm\nexcept ImportError as err:\n    from .sync_batchnorm import SyncBatchNorm\n    SyncBatchNorm.syncbn_import_error = err\n\ndef convert_syncbn_model(module, process_group=None, channel_last=False):\n    '''\n    Recursively traverse module and its children to replace all instances of\n    ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`.\n\n    All ``torch.nn.BatchNorm*N*d`` wrap around\n    ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily switch\n    to use sync BN.\n\n    Args:\n        module (torch.nn.Module): input module\n\n    Example::\n\n        >>> # model is an instance of torch.nn.Module\n        >>> import apex\n        >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)\n    '''\n    mod = module\n    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):\n        return module\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)\n        mod.running_mean = module.running_mean\n        mod.running_var = module.running_var\n        mod.num_batches_tracked = module.num_batches_tracked\n        if module.affine:\n            mod.weight.data = module.weight.data.clone().detach()\n            mod.bias.data 
= module.bias.data.clone().detach()\n    for name, child in module.named_children():\n        mod.add_module(name, convert_syncbn_model(child,\n                                                  process_group=process_group,\n                                                  channel_last=channel_last))\n    # TODO(jie) should I delete model explicitly?\n    del module\n    return mod\n\ndef create_syncbn_process_group(group_size):\n    '''\n    Creates process groups to be used for syncbn of a given ``group_size`` and returns the\n    process group that the current GPU participates in.\n\n    ``group_size`` must divide the total number of GPUs (world_size).\n\n    A ``group_size`` of 0 is treated as ``world_size``; in this case ``None`` is returned.\n\n    A ``group_size`` of 1 is equivalent to using non-sync bn, but will still carry the overhead.\n\n    Args:\n        group_size (int): number of GPUs to collaborate on sync bn\n\n    Example::\n\n        >>> # model is an instance of torch.nn.Module\n        >>> import apex\n        >>> group = apex.parallel.create_syncbn_process_group(group_size)\n    '''\n\n    if group_size == 0:\n        return None\n\n    world_size = torch.distributed.get_world_size()\n    assert(world_size >= group_size)\n    assert(world_size % group_size == 0)\n\n    group = None\n    for group_num in range(world_size // group_size):\n        group_ids = range(group_num * group_size, (group_num + 1) * group_size)\n        cur_group = torch.distributed.new_group(ranks=group_ids)\n        if torch.distributed.get_rank() // group_size == group_num:\n            group = cur_group\n            # cannot drop out and return here; every process must go through the creation of all subgroups\n\n    assert(group is not None)\n    return group\n"
  },
  {
    "path": "KoSentenceT5/apex/parallel/distributed.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom torch.nn.modules import Module\nfrom torch.autograd import Variable\nfrom collections import OrderedDict\nfrom itertools import chain\nimport copy\nimport importlib\nfrom ..multi_tensor_apply import multi_tensor_applier\n\nimported_flatten_impl = False\n\ndef import_flatten_impl():\n    global flatten_impl, unflatten_impl, imported_flatten_impl\n    try:\n        import apex_C\n        flatten_impl = apex_C.flatten\n        unflatten_impl = apex_C.unflatten\n    except ImportError:\n        print(\"Warning:  apex was installed without --cpp_ext.  Falling back to Python flatten and unflatten.\")\n        flatten_impl = torch._utils._flatten_dense_tensors\n        unflatten_impl = torch._utils._unflatten_dense_tensors\n    imported_flatten_impl = True\n\ndef flatten(bucket):\n    if not imported_flatten_impl:\n        import_flatten_impl()\n    return flatten_impl(bucket)\n\ndef unflatten(coalesced, bucket):\n    if not imported_flatten_impl:\n        import_flatten_impl()\n    return unflatten_impl(coalesced, bucket)\n\n# apply_dist_call requires that tensors in 'bucket' are all the same type.\ndef apply_flat_dist_call(bucket, call, extra_args=None):\n\n    coalesced = flatten(bucket)\n\n    if extra_args is not None:\n        call(coalesced, *extra_args)\n    else:\n        call(coalesced)\n\n    if call is dist.all_reduce:\n        coalesced /= dist.get_world_size()\n\n    for buf, synced in zip(bucket, unflatten(coalesced, bucket)):\n        buf.copy_(synced)\n\ndef split_half_float_double(tensors):\n    dtypes = [\"torch.cuda.HalfTensor\",  \"torch.cuda.FloatTensor\", \"torch.cuda.DoubleTensor\"]\n    buckets = []\n    for i, dtype in enumerate(dtypes):\n        bucket = [t for t in tensors if t.type() == dtype]\n        if bucket:\n            buckets.append(bucket)\n    return buckets\n\ndef split_by_type(tensors):\n    buckets = OrderedDict()\n    for tensor in tensors:\n        tp = 
tensor.type()\n        if tp not in buckets:\n            buckets[tp] = []\n        buckets[tp].append(tensor)\n    return buckets\n\n# flat_dist_call organizes 'tensors' by type.\ndef flat_dist_call(tensors, call, extra_args=None):\n    buckets = split_by_type(tensors)\n\n    for tp in buckets:\n        bucket = buckets[tp]\n        apply_flat_dist_call(bucket, call, extra_args)\n\n\ndef extract_tensors(maybe_tensor, tensor_list):\n    if torch.is_tensor(maybe_tensor):\n        tensor_list.append(maybe_tensor)\n    else:\n        try:\n            for item in maybe_tensor:\n                extract_tensors(item, tensor_list)\n        except TypeError:\n            return\n\n\nclass Reducer(object):\n    \"\"\"\n    :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters\n    across processes.  :class:`Reducer` is intended to give the user additional control:\n    Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce\n    parameters during ``backward()``.\n    Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.\n    This enables, for example, delaying the allreduce to be carried out every\n    several iterations instead of every single iteration.\n\n    Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces\n    over the number of participating processes.\n\n    :class:`Reducer` is designed to work with the upstream launch utility script\n    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.\n    When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.\n    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.\n\n    Args:\n        module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced.  
If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values.  If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.\n    \"\"\"\n\n    def __init__(self, module_or_grads_list):\n        if isinstance(module_or_grads_list, Module):\n            self.module = module_or_grads_list\n            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )\n\n        else:\n            self.module = None\n            self.grads = []\n            extract_tensors(module_or_grads_list, self.grads)\n\n    def reduce(self):\n        if self.module:\n            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]\n            flat_dist_call(grads, dist.all_reduce)\n        else:\n            flat_dist_call(self.grads, dist.all_reduce)\n\n\nclass DistributedDataParallel(Module):\n    \"\"\"\n    :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables\n    easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``.  Parameters are broadcast across participating processes on initialization, and gradients are\n    allreduced and averaged over processes during ``backward()``.\n\n    :class:`DistributedDataParallel` is optimized for use with NCCL.  
It achieves high performance by\n    overlapping communication with computation during ``backward()`` and bucketing smaller gradient\n    transfers to reduce the total number of transfers required.\n\n    :class:`DistributedDataParallel` is designed to work with the upstream launch utility script\n    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.\n    When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.\n    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.\n\n    https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage.\n    https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example\n    that combines :class:`DistributedDataParallel` with mixed precision training.\n\n    Args:\n        module: Network definition to be run in multi-gpu/distributed mode.\n        message_size (int, default=1e7): Minimum number of elements in a communication bucket.\n        delay_allreduce (bool, default=False):  Delay all communication to the end of the backward pass.  This disables overlapping communication with computation.\n        allreduce_trigger_params (list, optional, default=None):  If supplied, should contain a list of parameters drawn from the model.  Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full).  At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically.  If allreduce_trigger_params is supplied, the message_size argument will be ignored.\n        allreduce_always_fp32 (bool, default=False):  Convert any FP16 gradients to FP32 before allreducing.  
This can improve stability for widely scaled-out runs.\n        gradient_average (bool, default=True):  Option to toggle whether or not DDP averages the allreduced gradients over processes.  For proper scaling, the default value of True is recommended.\n        gradient_predivide_factor (float, default=1.0):  Allows performing the average of gradients over processes partially before and partially after the allreduce.  Before allreduce:  ``grads.mul_(1.0/gradient_predivide_factor)``.  After allreduce:  ``grads.mul_(gradient_predivide_factor/world_size)``.  This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs.\n\n    .. warning::\n        If ``gradient_average=False``, the pre-allreduce division (``grads.mul_(1.0/gradient_predivide_factor)``) will still be applied, but the post-allreduce gradient averaging (``grads.mul_(gradient_predivide_factor/world_size)``) will be omitted.\n\n    \"\"\"\n\n    def __init__(self,\n                 module,\n                 message_size=10000000,\n                 delay_allreduce=False,\n                 shared_param=None,\n                 allreduce_trigger_params=None,\n                 retain_allreduce_buffers=False,\n                 allreduce_always_fp32=False,\n                 num_allreduce_streams=1,\n                 allreduce_communicators=None,\n                 gradient_average=True,\n                 gradient_predivide_factor=1.0,\n                 gradient_average_split_factor=None,\n                 prof=False):\n        super(DistributedDataParallel, self).__init__()\n\n        # Backward/forward compatibility around\n        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and\n        # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86\n        if hasattr(dist, \"get_backend\"):\n            self._backend = dist.get_backend()\n            if hasattr(dist, \"DistBackend\"):\n                
self.backend_enum_holder = dist.DistBackend\n            else:\n                self.backend_enum_holder = dist.Backend\n        else:\n            self._backend = dist._backend\n            self.backend_enum_holder = dist.dist_backend\n\n        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False\n\n        self.prof = prof\n\n        self.allreduce_different_streams = (num_allreduce_streams > 1)\n        self.num_allreduce_streams = num_allreduce_streams\n        self.allreduce_communicators = allreduce_communicators\n        if self.allreduce_communicators:\n            assert len(allreduce_communicators[0]) == num_allreduce_streams\n            assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])\n            assert self.allreduce_different_streams\n\n        if self.allreduce_different_streams and delay_allreduce:\n            raise ValueError(\"self.allreduce_different_streams may only be used if delay_allreduce=False.\")\n\n        if shared_param is not None:\n            raise ValueError(\"shared_param is no longer supported as an option.  It was misleadingly named from the start.  It turns out overlapping communication with computation should work fine with shared parameters.  
If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True instead.\")\n\n        self.world_size = float(dist.get_world_size())\n\n        self.retain_allreduce_buffers = retain_allreduce_buffers\n        self.allreduce_always_fp32 = allreduce_always_fp32\n        self.gradient_average = gradient_average\n        self.gradient_predivide_factor = gradient_predivide_factor\n\n        self.custom_allreduce_triggers = False\n        if allreduce_trigger_params is not None:\n            if delay_allreduce:\n                raise ValueError(\"Setting allreduce_trigger_params is only valid if delay_allreduce=False.\")\n            self.custom_allreduce_triggers = True\n            self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])\n\n        self.delay_allreduce = delay_allreduce\n        self.message_size = message_size\n\n        self.main_stream = torch.cuda.current_stream()\n\n        self.bucket_streams = []\n        self.bucket_events = []\n\n        self.module = module\n\n        self._disable_allreduce = False\n\n        if self._backend == self.backend_enum_holder.NCCL:\n            for param in self.module.parameters():\n                assert param.is_cuda, \"NCCL backend only supports model parameters to be on GPU.\"\n\n        self.active_params = []\n\n        self.param_type_to_tmp_i = {\"torch.cuda.HalfTensor\" : 0,\n                                    \"torch.cuda.FloatTensor\" : 1,\n                                    \"torch.cuda.DoubleTensor\" : 2}\n\n        if multi_tensor_applier.available:\n            # TODO:  I really need to centralize the C++ backed imports\n            import amp_C\n            self.multi_tensor_scale = amp_C.multi_tensor_scale\n            self._overflow_buf = torch.cuda.IntTensor([0])\n\n        self.create_hooks()\n\n        flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )\n\n\n    def __setstate__(self, state):\n        super(DistributedDataParallel, self).__setstate__(state)\n        if self.allreduce_different_streams and self.delay_allreduce:\n            raise ValueError(\"self.allreduce_different_streams may only be used if delay_allreduce=False.\")\n\n        if self.delay_allreduce:\n            self.needs_refresh = True\n\n        self.bucket_streams = []\n        self.bucket_events = []\n\n\n    def __getstate__(self):\n        attrs = copy.copy(self.__dict__)\n        if self._backend != self.backend_enum_holder.NCCL:\n            del attrs['bucket_streams']\n            del attrs['bucket_events']\n        return attrs\n\n    def enable_allreduce(self):\n        self._disable_allreduce = False\n\n    def disable_allreduce(self):\n        self._disable_allreduce = True\n\n    # Broadcast rank 0's bucket structure across all processes, and have all processes\n    # regenerate their bucket structures to match.\n    def sync_bucket_structure(self):\n        # Append leftover buckets\n        for tmp_bucket in self.tmp_buckets:\n            if len(tmp_bucket) > 0:\n                self.active_i_buckets.append(tmp_bucket)\n\n        self.num_buckets = len(self.active_i_buckets)\n        self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]\n\n        info_tensor = torch.cuda.IntTensor([self.num_buckets] +\n                                           self.bucket_sizes +\n                                           list(chain(*self.active_i_buckets)))\n\n        dist.broadcast(info_tensor, 0)\n\n        info = [int(entry) for entry in info_tensor]\n\n        self.num_buckets = info[0]\n        self.bucket_sizes = info[1:self.num_buckets + 1]\n        self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                        for i in range(self.num_buckets)]\n        # Technically, active_i_buckets' work is done.  But the information is still useful to\n        # keep around.  
Therefore, refresh active_i_buckets based on rank 0 as well.\n        self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]\n                                 for i in range(self.num_buckets)]\n\n        flattened_buckets = info[self.num_buckets + 1:]\n        flat_i = 0\n        for bucket_idx in range(self.num_buckets):\n            for bucket_loc in range(self.bucket_sizes[bucket_idx]):\n                param_i = flattened_buckets[flat_i]\n                self.active_i_buckets[bucket_idx][bucket_loc] = param_i\n                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)\n                flat_i += 1\n\n\n    def create_hooks(self):\n        # Fallback hook that's only called at the end of backward.\n        # Used if you deliberately want to delay allreduces to the end, or to refresh the\n        # bucket structure that will be used to overlap communication with computation in later\n        # iterations.\n        def allreduce_params():\n            # Bucket record refresh\n            if not self.delay_allreduce:\n                if self.needs_refresh:\n                    self.sync_bucket_structure()\n\n                    self.needs_refresh = False\n\n            self.allreduce_fallback()\n\n\n        def overlapping_backward_epilogue():\n            for stream, event in zip(self.bucket_streams, self.bucket_events):\n                stream.record_event(event)\n                torch.cuda.current_stream().wait_event(event)\n\n            # Sanity checks that all the buckets were kicked off\n            if self.next_bucket != self.num_buckets:\n                raise RuntimeError(\"In epilogue, next_bucket ({}) != num_buckets ({}).  
\".format(\n                                   self.next_bucket, self.num_buckets),\n                                   \"This probably indicates some buckets were not allreduced.\")\n\n            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):\n                if actual != expected:\n                    raise RuntimeError(\"Some param buckets were not allreduced.\")\n\n\n        self.grad_accs = []\n        for param in self.module.parameters():\n            if param.requires_grad:\n                def wrapper(param):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n\n                    def allreduce_hook(*unused):\n                        if self.prof:\n                            torch.cuda.nvtx.range_push(\"allreduce_hook\")\n\n                        if not self._disable_allreduce:\n                            if self.delay_allreduce or self.needs_refresh:\n                                # TODO:  How do we want to handle multiple backward passes between\n                                # each forward, e.g., backward passes with retain_graph=True?\n                                # needs_refresh and callback_queued are both vulnerable states.\n                                if not self.delay_allreduce and self.needs_refresh:\n                                    # Use the backward pass to build the bucket structure on the fly.\n                                    active_i = self.param_id_to_active_i[id(param)]\n\n                                    # Float, half, and double tensors are grouped into buckets separately.\n                                    current_type = self.param_type_to_tmp_i[param.type()]\n\n                                    self.tmp_buckets[current_type].append(active_i)\n\n                                    ship_tmp_bucket = False\n                                    if self.custom_allreduce_triggers:\n                               
         if id(param) in self.allreduce_trigger_params:\n                                            ship_tmp_bucket = True\n                                    else:\n                                        self.tmp_numels[current_type] += param.numel()\n                                        if self.tmp_numels[current_type] >= self.message_size:\n                                            ship_tmp_bucket = True\n\n                                    # To consider:  If custom_allreduce_triggers are in use, ship all\n                                    # tmp_buckets, not just tmp_buckets[current_type].\n                                    if ship_tmp_bucket:\n                                        self.active_i_buckets.append(self.tmp_buckets[current_type])\n                                        self.tmp_buckets[current_type] = []\n                                        self.tmp_numels[current_type] = 0\n\n                                if not self.callback_queued:\n                                    Variable._execution_engine.queue_callback(allreduce_params)\n                                    self.callback_queued = True\n                            else:\n                                if not self.callback_queued:\n                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)\n                                    self.callback_queued = True\n\n                                self.comm_ready_buckets(param)\n\n                        if self.prof:\n                            torch.cuda.nvtx.range_pop()\n\n                    grad_acc.register_hook(allreduce_hook)\n                    self.grad_accs.append(grad_acc)\n\n                wrapper(param)\n\n\n    def _stream_this_bucket(self, bucket_idx):\n        if self.allreduce_different_streams:\n            return self.bucket_streams[bucket_idx%self.num_allreduce_streams]\n        else:\n            return self.bucket_streams[0]\n\n\n    def 
_event_this_bucket(self, bucket_idx):\n        if self.allreduce_different_streams:\n            return self.bucket_events[bucket_idx%self.num_allreduce_streams]\n        else:\n            return self.bucket_events[0]\n\n\n    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):\n        tensor = flatten(bucket)\n\n        if force_default_stream:\n            bucket_stream = self.main_stream\n        else:\n            bucket_stream = self._stream_this_bucket(bucket_idx)\n            bucket_event = self._event_this_bucket(bucket_idx)\n            torch.cuda.current_stream().record_event(bucket_event)\n            bucket_stream.wait_event(bucket_event)\n\n        with torch.cuda.stream(bucket_stream):\n            # self.main_stream.wait_stream(torch.cuda.current_stream())\n            # torch.cuda.synchronize()\n\n            tensor_to_allreduce = tensor\n\n            if self.allreduce_always_fp32:\n                tensor_to_allreduce = tensor.float()\n\n            if self.gradient_predivide_factor != 1.0:\n                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)\n\n            if self.allreduce_different_streams and not force_default_stream:\n                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])\n            else:\n                dist.all_reduce(tensor_to_allreduce)\n\n            if self.gradient_average:\n                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)\n\n            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:\n                tensor.copy_(tensor_to_allreduce)\n\n            if not self.retain_allreduce_buffers:\n                if multi_tensor_applier.available:\n                    multi_tensor_applier(\n                        self.multi_tensor_scale,\n                        self._overflow_buf,\n                        [unflatten(tensor, bucket), bucket],\n                        1.0)\n        
        else:\n                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):\n                        buf.copy_(synced)\n\n            # I think we actually do need this here.  After allreduce_bucket returns, tensor will\n            # eventually go out of scope and die, at which point it could otherwise be freed for\n            # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.\n            tensor.record_stream(bucket_stream)\n\n        return tensor\n\n\n    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):\n        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)\n        if self.retain_allreduce_buffers:\n            if self.allreduce_buffers[bucket_idx] is not None:\n                raise RuntimeError(\"The backward pass is attempting to replace an already-filled \"\n                                   \"allreduce buffer.  This is almost certainly an error.\")\n            self.allreduce_buffers[bucket_idx] = allreduced\n            for view, grad in zip(unflatten(allreduced, bucket), bucket):\n                grad.data = view\n            # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):\n            #     buf.copy_(synced)\n\n\n    def allreduce_fallback(self):\n        for stream, event in zip(self.bucket_streams, self.bucket_events):\n            stream.record_event(event)\n            torch.cuda.current_stream().wait_event(event)\n\n        if self.retain_allreduce_buffers:\n            grads = [param.grad for param in self.module.parameters() if param.grad is not None]\n        else:\n            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]\n\n        split_buckets = split_half_float_double(grads)\n\n        # If retain_allreduce_buffers is True and delay_allreduce is False,\n        # this will only be done during the first backward pass, ignored by the\n        # 
training script, and overwritten in the next forward pass.  So it's harmless.\n        if self.retain_allreduce_buffers:\n            self.allreduce_buffers = [None for _ in range(len(split_buckets))]\n\n        for i, bucket in enumerate(split_buckets):\n            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)\n\n\n    def comm_ready_buckets(self, param):\n        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.\n        # self.reduction_stream.wait_stream(torch.cuda.current_stream())\n        if self.prof:\n            torch.cuda.nvtx.range_push(\"comm_ready_buckets\")\n\n        bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]\n\n        if self.buckets[bucket_idx][bucket_loc] is not None:\n            raise RuntimeError(\"The backward pass is attempting to replace an already-filled \"\n                               \"bucket slot.  This is almost certainly an error.\")\n\n        if self.retain_allreduce_buffers:\n            self.buckets[bucket_idx][bucket_loc] = param.grad\n        else:\n            self.buckets[bucket_idx][bucket_loc] = param.grad.data\n\n        self.buckets_ready_size[bucket_idx] += 1\n\n        if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:\n            if bucket_idx == self.next_bucket:\n                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)\n\n                self.next_bucket += 1\n\n                # Reversing upstream's logic here, because we constructed our buckets based on\n                # the order things were received during backward.\n                if len(self.ready_buckets_not_reduced) > 0:\n                    sorted_todo = sorted(self.ready_buckets_not_reduced)\n                    for i in sorted_todo:\n                        # Nothing can be reduced now\n                        if i > self.next_bucket:\n                            break\n                        elif i == 
self.next_bucket:\n                            self.allreduce_maybe_retain(self.buckets[i], i)\n                            self.ready_buckets_not_reduced.remove(i)\n                            self.next_bucket += 1\n                        else:\n                            raise ValueError(\"i should always be >= next_bucket\")\n            else:\n                self.ready_buckets_not_reduced.add(bucket_idx)\n\n        if self.prof:\n            torch.cuda.nvtx.range_pop()\n\n\n    def forward(self, *inputs, **kwargs):\n        result = self.module(*inputs, **kwargs)\n\n        if self.prof:\n            torch.cuda.nvtx.range_push(\"forward pass DDP logic\")\n\n        if not self._disable_allreduce:\n            if not self.delay_allreduce:\n                param_list = [param for param in self.module.parameters() if param.requires_grad]\n\n                # Conditions under which to refresh self.record\n                # Forward has the authority to set needs_refresh to True, but only allreduce_params\n                # in backward has the authority to set needs_refresh to False.\n                # Parentheses are not necessary for correct order of operations, but make the intent clearer.\n                if ((not self.active_params) or\n                    (len(param_list) != len(self.active_params)) or\n                    any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):\n                    self.needs_refresh = True\n\n                if self.needs_refresh:\n                    self.active_i_buckets = []\n                    self.buckets = []\n                    self.tmp_buckets = [[], [], []] # [running half, float, double buckets]\n                    self.tmp_numels = [0, 0, 0]\n                    self.bucket_sizes = []\n                    self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}\n                    self.param_id_to_bucket = {}\n                    self.bucket_pgs = []\n  
                  self.bucket_streams = []\n                    self.bucket_events = []\n                else:\n                    # self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                    #                 for i in range(self.num_buckets)]\n                    if not self.buckets:\n                        self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                                        for i in range(self.num_buckets)]\n                    else:\n                        assert len(self.buckets) == self.num_buckets, \"len(buckets) = {}, expected {}\".format(\n                            len(self.buckets), self.num_buckets)\n                        for b, bucket in enumerate(self.buckets):\n                            assert len(bucket) == self.bucket_sizes[b], \"len(buckets[{}]) = {}, expected {}\".format(\n                                b, len(bucket), self.bucket_sizes[b])\n                            for i in range(len(bucket)):\n                                bucket[i] = None\n\n                    if self.allreduce_communicators:\n                        self.bucket_pgs = self.allreduce_communicators[0]\n                        self.bucket_streams = self.allreduce_communicators[1]\n                        self.bucket_events = [torch.cuda.Event(enable_timing=False,\n                                            blocking=False) for _ in range(self.num_allreduce_streams)]\n                    else:\n                        if self.allreduce_different_streams:\n                            if not self.bucket_pgs:\n                                self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]\n                                for i, bg in enumerate(self.bucket_pgs):\n                                    print(\"rank {} created group {} with backend {}\".format(\n                                          dist.get_rank(), i, dist.get_backend(bg)))\n                        if self.allreduce_different_streams:\n                            if not self.bucket_streams:\n                                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]\n                                self.bucket_events = [torch.cuda.Event(enable_timing=False,\n                                                      blocking=False) for _ in range(self.num_allreduce_streams)]\n                        else:\n                            if not self.bucket_streams:\n                                self.bucket_streams = [torch.cuda.Stream()]\n                                self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]\n\n                    self.buckets_ready_size = [0 for i in range(self.num_buckets)]\n                    if self.retain_allreduce_buffers:\n                        self.allreduce_buffers = [None for _ in range(self.num_buckets)]\n                    self.next_bucket = 0\n                    self.ready_buckets_not_reduced = set()\n\n                self.active_params = param_list\n\n            self.callback_queued = False\n\n        if self.prof:\n            torch.cuda.nvtx.range_pop()\n\n        return result\n"
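In `create_hooks` above, the bucket structure is assembled on the fly during the first backward pass: each ready gradient is appended to a per-dtype temporary bucket, and the bucket is shipped for allreduce once at least `message_size` elements have accumulated, with leftovers flushed by the backward epilogue. A minimal pure-Python sketch of that size heuristic, where `build_buckets` is a hypothetical helper for illustration rather than apex API:

```python
# Hypothetical sketch (not apex API) of the bucket-assembly heuristic used
# during the first backward pass: gradients are grouped, in the order their
# hooks fire, into buckets that are shipped for allreduce once at least
# `message_size` elements have accumulated; leftovers form a final bucket.
def build_buckets(grad_numels, message_size=10000000):
    buckets = []                       # each bucket: list of parameter indices
    tmp_bucket, tmp_numel = [], 0
    for i, numel in enumerate(grad_numels):
        tmp_bucket.append(i)
        tmp_numel += numel
        if tmp_numel >= message_size:  # bucket full: would kick off an allreduce
            buckets.append(tmp_bucket)
            tmp_bucket, tmp_numel = [], 0
    if tmp_bucket:                     # leftover, flushed by the backward epilogue
        buckets.append(tmp_bucket)
    return buckets
```

With the default `message_size` of 1e7 elements, many small layers coalesce into a few large transfers while a very large layer still gets a bucket of its own.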
  },
  {
    "path": "KoSentenceT5/apex/parallel/multiproc.py",
    "content": "import torch\nimport sys\nimport subprocess\n\ndef docstring_hack():\n    \"\"\"\n    Multiproc file which will launch a set of processes locally for multi-gpu\n    usage: python -m apex.parallel.multiproc main.py ...\n    \"\"\"\n    pass\n\nargslist = list(sys.argv)[1:]\nworld_size = torch.cuda.device_count()\n\nif '--world-size' in argslist:\n    world_size = int(argslist[argslist.index('--world-size')+1])\nelse:\n    argslist.append('--world-size')\n    argslist.append(str(world_size))\n\nworkers = []\n\nfor i in range(world_size):\n    if '--rank' in argslist:\n        argslist[argslist.index('--rank')+1] = str(i)\n    else:\n        argslist.append('--rank')\n        argslist.append(str(i))\n    stdout = None if i == 0 else open(\"GPU_\"+str(i)+\".log\", \"w\")\n    print(argslist)\n    p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)\n    workers.append(p)\n\nfor p in workers:\n    p.wait()\n"
  },
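The `multiproc.py` launcher above reuses its own argument list for every local process, overriding only `--rank` and appending `--world-size` when the caller did not pass one. That rewriting can be sketched as a pure function; `per_rank_args` is a hypothetical name used only for illustration:

```python
# Hypothetical helper illustrating multiproc.py's argv rewriting: every local
# worker receives the same arguments except for its own `--rank`;
# `--world-size` is appended when absent from the original command line.
def per_rank_args(argslist, world_size):
    base = list(argslist)
    if '--world-size' not in base:
        base += ['--world-size', str(world_size)]
    all_args = []
    for rank in range(world_size):
        args = list(base)
        if '--rank' in args:
            args[args.index('--rank') + 1] = str(rank)
        else:
            args += ['--rank', str(rank)]
        all_args.append(args)
    return all_args
```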
  {
    "path": "KoSentenceT5/apex/parallel/optimized_sync_batchnorm.py",
    "content": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\n\nimport syncbn\nfrom .optimized_sync_batchnorm_kernel import SyncBatchnormFunction\n\n\nclass SyncBatchNorm(_BatchNorm):\n    \"\"\"\n    synchronized batch normalization module extended from `torch.nn.BatchNormNd`\n    with the added stats reduction across multiple processes.\n    :class:`apex.parallel.SyncBatchNorm` is designed to work with\n    `DistributedDataParallel`.\n\n    When running in training mode, the layer reduces stats across all processes\n    to increase the effective batch size for the normalization layer. This is useful\n    in applications where the per-process batch size is small, which would otherwise\n    diminish the converged accuracy of the model. The module uses the collective\n    communication package from `torch.distributed`.\n\n    When running in evaluation mode, the layer falls back to\n    `torch.nn.functional.batch_norm`.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. Default: ``True``\n        process_group: pass in a process group within which the stats of the\n            mini-batch are being synchronized. 
``None`` for using default process\n            group\n        channel_last: a boolean value that when set to ``True``, this module\n            takes the last dimension of the input tensor to be the channel\n            dimension. Default: False\n\n    Examples::\n        >>> # channel first tensor\n        >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()\n        >>> inp = torch.randn(10, 100, 14, 14).cuda()\n        >>> out = sbn(inp)\n        >>> inp = torch.randn(3, 100, 20).cuda()\n        >>> out = sbn(inp)\n        >>> # channel last tensor\n        >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()\n        >>> inp = torch.randn(10, 14, 14, 100).cuda()\n    \"\"\"\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):\n        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)\n        self.process_group = process_group\n        self.channel_last = channel_last\n        self.fuse_relu = fuse_relu\n\n    def _specify_process_group(self, process_group):\n        self.process_group = process_group\n\n    def _specify_channel_last(self, channel_last):\n        self.channel_last = channel_last\n\n    def forward(self, input, z=None):\n        # if input.dim() == 2, we switch to channel_last for efficient memory accessing\n        channel_last = self.channel_last if input.dim() != 2 else True\n\n        if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z is None:\n            # fall back to pytorch implementation for inference\n            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)\n        else:\n            exponential_average_factor = 0.0\n            if self.training and self.track_running_stats:\n                self.num_batches_tracked += 1\n                if self.momentum is None:\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n                else:\n                    exponential_average_factor = self.momentum\n            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)\n"
  },
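In the training path of `SyncBatchNorm.forward` above, `exponential_average_factor` is either the fixed `momentum` or, when `momentum` is `None`, `1/num_batches_tracked`, which turns the running-stat update into a cumulative (simple) average. A small sketch of that update rule, with `update_running_stat` as a hypothetical helper:

```python
# Hypothetical sketch of the running-stat update implied above: with
# momentum=None the factor 1/num_batches_tracked yields a cumulative (simple)
# average of the batch statistics seen so far; otherwise the fixed momentum
# gives an exponential moving average.
def update_running_stat(running, batch_stat, num_batches_tracked, momentum=0.1):
    if momentum is None:
        factor = 1.0 / float(num_batches_tracked)
    else:
        factor = momentum
    return (1.0 - factor) * running + factor * batch_stat
```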
  {
    "path": "KoSentenceT5/apex/parallel/optimized_sync_batchnorm_kernel.py",
    "content": "import torch\nfrom torch.autograd.function import Function\n\nimport syncbn\nfrom apex.parallel import ReduceOp\n\nclass SyncBatchnormFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):\n        input = input.contiguous()\n        world_size = 0\n\n        mean = None\n        var_biased = None\n        inv_std = None\n        var = None\n        out = None\n        count = None\n        if track_running_stats:\n            if channel_last:\n                count = int(input.numel()/input.size(-1))\n                mean, var_biased = syncbn.welford_mean_var_c_last(input)\n                num_channels = input.size(-1)\n            else:\n                count = int(input.numel()/input.size(1))\n                mean, var_biased = syncbn.welford_mean_var(input)\n                num_channels = input.size(1)\n\n            if torch.distributed.is_initialized():\n                if not process_group:\n                    process_group = torch.distributed.group.WORLD\n                device = mean.device\n                world_size = torch.distributed.get_world_size(process_group)\n\n                count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)\n                combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)\n                combined_list = [torch.empty_like(combined) for k in range(world_size)]\n                torch.distributed.all_gather(combined_list, combined, process_group)\n                combined = torch.stack(combined_list, dim=0)\n                mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)\n                count_all = count_all.view(-1)\n                mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps)\n            else:\n           
     device = mean.device\n                count_all = torch.cuda.IntTensor([count], device=device)\n                inv_std = 1.0 / torch.sqrt(var_biased + eps)\n                var = var_biased * (count) / (count-1)\n\n            if count == 1 and world_size < 2:\n                raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))\n\n            r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()\n            r_v_inc = var if running_variance.dtype != torch.float16 else var.half()\n            running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc\n            running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc\n        else:\n            mean = running_mean.data\n            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)\n\n        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32))\n        ctx.process_group = process_group\n        ctx.channel_last = channel_last\n        ctx.world_size = world_size\n        ctx.fuse_relu = fuse_relu\n\n        if channel_last:\n            out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)\n        else:\n            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)\n\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_output = grad_output.contiguous()\n        # mini batch mean & var are calculated by forward path.\n        # mu = 1./N*np.sum(h, axis = 0)\n        # var = 1./N*np.sum((h-mu)**2, axis = 0)\n        saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors\n        process_group = ctx.process_group\n        channel_last = ctx.channel_last\n        world_size = ctx.world_size\n        fuse_relu = ctx.fuse_relu\n        grad_input = grad_z = grad_weight = grad_bias = None\n\n        if fuse_relu:\n            grad_output = 
syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)\n        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:\n            grad_z = grad_output.clone()\n\n        # TODO: update kernel to not pre_divide by item_num\n        if channel_last:\n            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)\n        else:\n            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)\n\n        # calculate grad_input\n        if ctx.needs_input_grad[0]:\n\n            if torch.distributed.is_initialized():\n                num_channels = sum_dy.shape[0]\n                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)\n                torch.distributed.all_reduce(\n                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)\n                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)\n\n            if channel_last:\n                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)\n            else:\n                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)\n\n        if weight is None or not ctx.needs_input_grad[2]:\n            grad_weight = None\n\n        if weight is None or not ctx.needs_input_grad[3]:\n            grad_bias = None\n\n        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None\n"
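The forward pass in this kernel all-gathers each process's `(mean, var_biased, count)` and merges them via `syncbn.welford_parallel`. Numerically, such a merge is equivalent to the parallel-variance combination (Chan et al.); the following is a hypothetical pure-Python sketch of that combination under this assumption, not the actual CUDA kernel:

```python
# Hypothetical sketch of merging per-process (mean, biased variance, count)
# triples into global statistics, as a parallel-variance combination.
def welford_combine(stats):
    total = sum(c for _, _, c in stats)
    mean = sum(m * c for m, _, c in stats) / total
    # Recover E[x^2] per process as var + mean^2, average, then re-center.
    var = sum((v + m * m) * c for m, v, c in stats) / total - mean * mean
    return mean, var, total
```

For the data split [1, 2] on one process and [3, 4, 5] on another, the merged mean and biased variance match those of the full set [1, 2, 3, 4, 5].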
  },
  {
    "path": "KoSentenceT5/apex/parallel/sync_batchnorm.py",
    "content": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\n\nfrom .sync_batchnorm_kernel import SyncBatchnormFunction\nfrom apex.parallel import ReduceOp\n\n\nclass SyncBatchNorm(_BatchNorm):\n    \"\"\"\n    synchronized batch normalization module extented from ``torch.nn.BatchNormNd``\n    with the added stats reduction across multiple processes.\n    :class:`apex.parallel.SyncBatchNorm` is designed to work with\n    ``DistributedDataParallel``.\n\n    When running in training mode, the layer reduces stats across all processes\n    to increase the effective batchsize for normalization layer. This is useful\n    in applications where batch size is small on a given process that would\n    diminish converged accuracy of the model. The model uses collective\n    communication package from ``torch.distributed``.\n\n    When running in evaluation mode, the layer falls back to\n    ``torch.nn.functional.batch_norm``.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. 
Default: ``True``\n\n    Example::\n\n        >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()\n        >>> inp = torch.randn(10, 100, 14, 14).cuda()\n        >>> out = sbn(inp)\n        >>> inp = torch.randn(3, 100, 20).cuda()\n        >>> out = sbn(inp)\n    \"\"\"\n\n    warned = False\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):\n        if channel_last == True:\n            raise AttributeError(\"channel_last is not supported by primitive SyncBatchNorm implementation. Try install apex with `--cuda_ext` if channel_last is desired.\")\n\n        if not SyncBatchNorm.warned:\n            if hasattr(self, \"syncbn_import_error\"):\n                print(\"Warning:  using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext.  The exception raised when attempting to import the cuda backend was: \", self.syncbn_import_error)\n            else:\n                print(\"Warning:  using Python fallback for SyncBatchNorm\")\n            SyncBatchNorm.warned = True\n\n        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)\n        self.process_group = process_group\n\n    def _specify_process_group(self, process_group):\n        self.process_group = process_group\n\n    def forward(self, input):\n        torch.cuda.nvtx.range_push(\"sync_bn_fw_with_mean_var\")\n        mean = None\n        var = None\n        cast = None\n        out = None\n\n        # casting to handle mismatch input type to layer type\n        if self.running_mean is not None:\n            if self.running_mean.dtype != input.dtype:\n                input = input.to(self.running_mean.dtype)\n                cast = input.dtype\n        elif self.weight is not None:\n            if self.weight.dtype != input.dtype:\n                input = input.to(self.weight.dtype)\n                
cast = input.dtype\n\n        if not self.training and self.track_running_stats:\n            # fall back to pytorch implementation for inference\n            torch.cuda.nvtx.range_pop()\n            out = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)\n        else:\n            process_group = self.process_group\n            world_size = 1\n            if not self.process_group:\n                process_group = torch.distributed.group.WORLD\n            self.num_batches_tracked += 1\n            with torch.no_grad():\n                channel_first_input = input.transpose(0, 1).contiguous()\n                squashed_input_tensor_view = channel_first_input.view(\n                    channel_first_input.size(0), -1)\n                # total number of data points for each variance entry. Used to calculate unbiased variance estimate\n                m = None\n                local_m = float(squashed_input_tensor_view.size()[1])\n                local_mean = torch.mean(squashed_input_tensor_view, 1)\n                local_sqr_mean = torch.pow(\n                    squashed_input_tensor_view, 2).mean(1)\n                if torch.distributed.is_initialized():\n                    world_size = torch.distributed.get_world_size(process_group)\n                    torch.distributed.all_reduce(\n                        local_mean, ReduceOp.SUM, process_group)\n                    mean = local_mean / world_size\n                    torch.distributed.all_reduce(\n                        local_sqr_mean, ReduceOp.SUM, process_group)\n                    sqr_mean = local_sqr_mean / world_size\n                    m = local_m * world_size\n                else:\n                    m = local_m\n                    mean = local_mean\n                    sqr_mean = local_sqr_mean\n                # var(x) = E (( x - mean_x ) ** 2)\n                #        = 1 / N * sum ( x - mean_x ) ** 2\n                #        = 1 / N * sum 
(x**2) - mean_x**2\n                var = sqr_mean - mean.pow(2)\n\n                if self.running_mean is not None:\n                    self.running_mean = self.momentum * mean + \\\n                        (1 - self.momentum) * self.running_mean\n                if self.running_var is not None:\n                    # as noted in the paper, we use the unbiased variance estimate of the mini-batch\n                    # Var[x] = m / (m-1) * Eb (sample_variance)\n                    self.running_var = m / \\\n                        (m-1) * self.momentum * var + \\\n                        (1 - self.momentum) * self.running_var\n            torch.cuda.nvtx.range_pop()\n            out = SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)\n        return out.to(cast)\n"
  },
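The Python fallback above recovers the global batch variance from two all-reduced per-channel vectors, the mean and the mean of squares, via `var(x) = E[x^2] - E[x]^2`. A minimal pure-Python sketch of that reduction (no `torch.distributed`; names are illustrative, and equal shard sizes are assumed, matching the code's `m = local_m * world_size`):

```python
def local_stats(xs):
    """Per-process statistics: (count, mean, mean of squares)."""
    m = len(xs)
    mean = sum(xs) / m
    sqr_mean = sum(x * x for x in xs) / m
    return m, mean, sqr_mean

def synced_batch_stats(shards):
    """Combine equally sized per-process shards, as the all_reduce path does."""
    world_size = len(shards)
    stats = [local_stats(s) for s in shards]
    mean = sum(mu for _, mu, _ in stats) / world_size      # all_reduce(SUM) / world_size
    sqr_mean = sum(sq for _, _, sq in stats) / world_size
    var = sqr_mean - mean ** 2                             # var(x) = E[x^2] - E[x]^2
    m = stats[0][0] * world_size                           # total sample count
    return mean, var, m

# The combined statistics match those of the concatenated batch.
shards = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
mean, var, m = synced_batch_stats(shards)
```

Averaging per-shard means is only valid because each process contributes the same number of elements; with unequal shards one would all-reduce sums and counts instead.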
  {
    "path": "KoSentenceT5/apex/parallel/sync_batchnorm_kernel.py",
    "content": "import torch\nfrom torch.autograd.function import Function\n\nfrom apex.parallel import ReduceOp\n\n\nclass SyncBatchnormFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size):\n        torch.cuda.nvtx.range_push(\"sync_BN_fw\")\n        # transpose it to channel last to support broadcasting for input with different rank\n        c_last_input = input.transpose(1, -1).contiguous().clone()\n\n        ctx.save_for_backward(c_last_input, weight, bias,\n                              running_mean, running_variance)\n        ctx.eps = eps\n        ctx.process_group = process_group\n        ctx.world_size = world_size\n\n        c_last_input = (c_last_input - running_mean) / \\\n            torch.sqrt(running_variance + eps)\n\n        if weight is not None:\n            c_last_input = c_last_input * weight\n        if bias is not None:\n            c_last_input = c_last_input + bias\n\n        torch.cuda.nvtx.range_pop()\n        return c_last_input.transpose(1, -1).contiguous().clone()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        torch.cuda.nvtx.range_push(\"sync_BN_bw\")\n        # mini batch mean & var are calculated by forward path.\n        # mu = 1./N*np.sum(h, axis = 0)\n        # var = 1./N*np.sum((h-mu)**2, axis = 0)\n        c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors\n\n        eps = ctx.eps\n        process_group = ctx.process_group\n        world_size = ctx.world_size\n        grad_input = grad_weight = grad_bias = None\n        num_features = running_mean.size()[0]\n\n        # transpose it to channel last to support broadcasting for input with different rank\n        torch.cuda.nvtx.range_push(\"carilli field\")\n        c_last_grad = grad_output.transpose(1, -1).contiguous()\n        # squash non-channel dimension so we can easily calculate mean\n        c_grad = c_last_grad.view(-1, 
num_features).contiguous()\n        torch.cuda.nvtx.range_pop()\n\n        # calculate grad_input\n        if ctx.needs_input_grad[0]:\n            # dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0)\n            #     - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0))\n            mean_dy = c_grad.mean(0)\n            mean_dy_xmu = (c_last_grad * (c_last_input -\n                                          running_mean)).view(-1, num_features).mean(0)\n            if torch.distributed.is_initialized():\n                torch.distributed.all_reduce(\n                    mean_dy, ReduceOp.SUM, process_group)\n                mean_dy = mean_dy / world_size\n                torch.distributed.all_reduce(\n                    mean_dy_xmu, ReduceOp.SUM, process_group)\n                mean_dy_xmu = mean_dy_xmu / world_size\n            c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (\n                running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)\n            if weight is not None:\n                c_last_grad_input.mul_(weight)\n            grad_input = c_last_grad_input.transpose(1, -1).contiguous()\n\n        # calculate grad_weight\n        grad_weight = None\n        if weight is not None and ctx.needs_input_grad[1]:\n            # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0)\n            grad_weight = ((c_last_input - running_mean) / torch.sqrt(\n                running_variance + eps) * c_last_grad).view(-1, num_features).sum(0)\n\n        # calculate grad_bias\n        grad_bias = None\n        if bias is not None and ctx.needs_input_grad[2]:\n            # dbeta = np.sum(dy, axis=0)\n            grad_bias = c_grad.sum(0)\n\n        torch.cuda.nvtx.range_pop()\n        return grad_input, grad_weight, grad_bias, None, None, None, None, None\n"
  },
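The `dgamma`/`dbeta` formulas in the backward comments above (`dgamma = sum((h - mu) * (var + eps)**(-1/2) * dy)`, `dbeta = sum(dy)`) can be sanity-checked numerically. This is a single-channel, pure-Python sketch with the batch statistics held fixed (as in the kernel, where they arrive as saved tensors); all names are illustrative:

```python
import math

def bn_forward(xs, mean, var, weight, bias, eps=1e-5):
    # y = gamma * (x - mu) / sqrt(var + eps) + beta, with fixed batch statistics
    return [weight * (x - mean) / math.sqrt(var + eps) + bias for x in xs]

def bn_backward(dys, xs, mean, var, eps=1e-5):
    # dbeta  = sum(dy)                       (grad_bias in the kernel)
    # dgamma = sum((x - mu) * inv_std * dy)  (grad_weight in the kernel)
    inv_std = 1.0 / math.sqrt(var + eps)
    grad_bias = sum(dys)
    grad_weight = sum((x - mean) * inv_std * dy for x, dy in zip(xs, dys))
    return grad_weight, grad_bias

def central_diff(f, v, h=1e-6):
    return (f(v + h) - f(v - h)) / (2.0 * h)

xs = [0.5, -1.0, 2.0, 0.25]
dys = [1.0, -0.5, 0.25, 2.0]
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
w, b = 1.3, -0.2

grad_w, grad_b = bn_backward(dys, xs, mean, var)
# Scalar surrogate loss sum(dy * y): its derivatives w.r.t. weight and bias
# are exactly dgamma and dbeta, so central differences should agree.
loss_w = lambda wv: sum(d * y for d, y in zip(dys, bn_forward(xs, mean, var, wv, b)))
loss_b = lambda bv: sum(d * y for d, y in zip(dys, bn_forward(xs, mean, var, w, bv)))
```

The `grad_input` path is not reproduced here since it additionally differentiates through the batch statistics, as the mean_dy / mean_dy_xmu terms in the kernel show.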
  {
    "path": "KoSentenceT5/apex/pyprof/FAQs.md",
    "content": "1. How do I intercept the Adam optimizer in APEX ?\n\n\t```python\n\tfrom apex import pyprof\n\timport fused_adam_cuda\n\tpyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\t```\n\n2. If you are using JIT and/or AMP, the correct initialization sequence is\n\t1. Let any JIT to finish.\n\t2. Initlialize pyprof `pyprof.nvtx.init()`.\n\t3. Initialize AMP.\n\n3. How do I profile with `torch.distributed.launch` ?\n\n\t```python\n\tnvprof -f -o net%p.sql \\\n\t\t--profile-from-start off \\\n\t\t--profile-child-processes \\\n\t\tpython -m torch.distributed.launch net.py\n\t```\n"
  },
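The FAQ's `pyprof.nvtx.wrap(fused_adam_cuda, 'adam')` works by monkey-patching the named attribute. A conceptual sketch of that idea, not pyprof's actual implementation: the real code pushes an NVTX range around each call, while here the markers go into a list so the mechanism can be shown without CUDA. `fake_mod` and `events` are illustrative stand-ins:

```python
import types

events = []  # stand-in for the NVTX range push/pop calls

def wrap(mod, fname):
    """Replace mod.fname with a wrapper that brackets each call with markers."""
    fn = getattr(mod, fname)
    def wrapper(*args, **kwargs):
        events.append(('push', fname))     # pyprof calls nvtx.range_push(...) here
        try:
            return fn(*args, **kwargs)
        finally:
            events.append(('pop', fname))  # ...and nvtx.range_pop() here
    setattr(mod, fname, wrapper)

# A hypothetical module exposing an 'adam' function, standing in for fused_adam_cuda.
fake_mod = types.SimpleNamespace(adam=lambda x: x + 1)
wrap(fake_mod, 'adam')
result = fake_mod.adam(41)  # runs the original function between the two markers
```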
  {
    "path": "KoSentenceT5/apex/pyprof/README.md",
    "content": "## PyProf - PyTorch Profiling tool\n\n### What does this tool do?                                                                                                                                                                                                                  \n\nAnalyzing the performance of deep neural networks is hard. Getting kernels out of [NvProf]([https://developer.nvidia.com/nvidia-visual-profiler](https://developer.nvidia.com/nvidia-visual-profiler)) or [NSight Compute]([https://developer.nvidia.com/nsight-compute](https://developer.nvidia.com/nsight-compute)) provides some generic kernel name and its execution time, but not detailed information regarding the following:\n\n - Which layer launched it: e.g. the association of `ComputeOffsetsKernel` with a concrete PyTorch layer or API is not obvious.\n - What the tensor dimensions and precision were: without knowing the tensor dimensions and precision, it's impossible to reason about whether the actual (silicon) kernel time is close to maximum performance of such a kernel on the GPU. Knowing the tensor dimensions and precision, we can figure out the FLOPs and bandwidth required by a layer, and then determine how close to maximum performance the kernel is for that operation.\n - Forward-backward correlation: currently it's very hard to determine what the forward pass step was that resulted in the particular weight and data gradients (wgrad, dgrad), which makes it difficult to determine the tensor dimensions required by these backprop steps to assess their performance.\n - Did the kernel use [Tensor Cores]([https://www.youtube.com/watch?v=yyR0ZoCeBO8](https://www.youtube.com/watch?v=yyR0ZoCeBO8))?\n - Which line in the user's code resulted in launching this particular kernel (program trace)?\n\nPyProf addresses all of the issues above by:\n\n 1. 
Instrumenting PyTorch operations to capture the tensor dimensions and precision using [NVTX](https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx). This information is recorded at profile capture time, e.g. using [NvProf](https://developer.nvidia.com/nvidia-visual-profiler).\n 2. Querying the record produced by the profiler to correlate the kernel name and duration with PyTorch API/layer name, tensor dimensions, tensor precision, as well as calculating FLOPs and bandwidth for common operations. In addition, extra information from the profile is added for use by CUDA professionals, such as CUDA launch parameters (block/grid dimensions).\n\nRegarding FLOP and bandwidth implementations, these are usually quite straightforward. For example, for matrices A<sub>MxK</sub> and B<sub>KxN</sub>, the FLOP count for a matrix multiplication is 2 * M * N * K, and bandwidth is M * K + N * K + M * N. Note that these numbers are based on the algorithm, not the actual performance of the specific kernel. For more details, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).\n\nArmed with such information, the user can determine various issues to help them tune the network. For instance, according to the [Tensor Core Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html), the M, N and K dimensions that result in Tensor Core usage need to be divisible by 8. In fact, PyProf comes with a flag that lets the user obtain information regarding whether Tensor Cores were used by the kernel. Other useful information might include knowing that a particular kernel did not exploit much thread parallelism, as determined by the grid/block dimensions. 
Since many PyTorch kernels are open-source (or even custom written by the user, as in [CUDA Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html)), this provides the user with information that helps root cause performance issues and prioritize optimization work.\n\n\n### How to get started?\n\n1. Add the following lines to your PyTorch network:\n\n    ```python\n    import torch.cuda.profiler as profiler\n    from apex import pyprof\n    pyprof.nvtx.init()\n    ```\n\n    Run the training/inference loop with [PyTorch's NVTX context manager](https://pytorch.org/docs/stable/_modules/torch/autograd/profiler.html#emit_nvtx)\n    `with torch.autograd.profiler.emit_nvtx()`. Optionally, you can\n    use `profiler.start()` and `profiler.stop()` to pick an iteration\n    (say after warm-up) for which you would like to capture data.\n    Here's an example:\n\n    ```python\n    iters = 500\n    iter_to_capture = 100\n\n    # Define network, loss function, optimizer etc.\n\n    # PyTorch NVTX context manager\n    with torch.autograd.profiler.emit_nvtx():\n\n        for iter in range(iters):\n\n            if iter == iter_to_capture:\n                profiler.start()\n\n            output = net(images)\n            loss = criterion(output, labels)\n            loss.backward()\n            optimizer.step()\n\n            if iter == iter_to_capture:\n                profiler.stop()\n    ```\n\n2. Run NVprof to generate a SQL (NVVP) file. 
This file can be opened with NVVP, as usual.\n    ```sh\n    # If you used profiler.start() and profiler.stop() in net.py\n    nvprof -f -o net.sql --profile-from-start off -- python net.py\n\n    # Profile everything\n    nvprof -f -o net.sql -- python net.py\n    ```\n\n**Note:** if you're experiencing issues with hardware counters and you get a message such as `**_ERR_NVGPUCTRPERM The user running <tool_name/application_name> does not have permission to access NVIDIA GPU Performance Counters on the target device_**`, please follow the steps described in [Hardware Counters](#hardware-counters).\n\n3. Run the parser on the SQL file. The output is an ASCII file. Each line\nis a Python dictionary which contains information about the kernel name,\nduration, parameters etc. This file can be used as input to other custom\nscripts as well.\n\n    ```sh\n    python -m apex.pyprof.parse net.sql > net.dict\n    ```\n\n4. Run the profiler. The input is the Python dictionary created above. The tool can produce a CSV output, a columnated output (similar to `column -t` for terminal readability) and a space-separated output (for post-processing by AWK, for instance). The tool produces 20 columns of information for every GPU kernel, but you can select a subset of columns using the `-c` flag. Note that a few columns might have the value \"na\", implying either that it is a work in progress or that the tool was unable to extract that information. Assuming the directory is `prof`, here are a few examples of how to use `prof.py`.\n\n    ```sh\n\t# Print usage and help. 
Lists all available output columns.\n    python -m apex.pyprof.prof -h\n\n\t# Columnated output of width 150 with some default columns.\n    python -m apex.pyprof.prof -w 150 net.dict\n\n\t# CSV output.\n    python -m apex.pyprof.prof --csv net.dict\n\n\t# Space separated output.\n    python -m apex.pyprof.prof net.dict\n\n\t# Columnated output of width 130 with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof -w 130 -c idx,dir,kernel,params,sil net.dict\n\n\t# CSV output with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof --csv -c idx,dir,kernel,params,sil net.dict\n\n\t# Space separated output with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof -c idx,dir,kernel,params,sil net.dict\n\n\t# Input redirection.\n    python -m apex.pyprof.prof < net.dict\n    ```\n\n5. Profile-guided optimization\n\nIf kernels that do matrix multiplication/GEMM or convolution use half precision (fp16) data but do not use Tensor Cores (the TC column in the profile analysis output doesn't show a \"1\"), one can follow some basic steps to increase the likelihood that a Tensor Core-compatible kernel will be chosen. For example, for GEMMs, M, N and K should be divisible by 8, and for convolutions, the number of input and output channels should be divisible by 8. 
For more information, see detailed Tensor Core guides such as:\n- Blog Post: [Tips for Optimizing GPU Performance Using Tensor Cores](https://devblogs.nvidia.com/optimizing-gpu-performance-tensor-cores/)\n- GTC Talk: [Tensor Core Deep Learning Performance Guide](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf)\n\nFor both Tensor Core and non-Tensor Core Deep Learning performance optimization tips, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).\n\n### TODOs\n1. Support for transposed convolution is currently missing.\n2. PyProf currently works only with NvProf, but Nsight Compute support will be added in the future.\n\n### Example\n\n1. Run `nvprof` on the LeNet model in `examples/lenet.py`. This will output a SQL file called `net.sql`.\n\n```sh\nnvprof -f -o net.sql --profile-from-start off -- python examples/lenet.py\n```\n\n**Note**: DO NOT add `--analysis-metrics`, since that will change which table nvprof writes the kernels to (`CUPTI_ACTIVITY_KIND_KERNEL` instead of the usual `CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL`). Support for running with metrics may be added in the future.\n\nIf you don't care about a full correlation analysis and you'd just like to view the timeline with detailed NVTX annotations, you can do so, e.g. in the NVIDIA Visual Profiler (NVVP). For example, you can call `nvvp net.sql` to view the annotated timeline.\n\n2. Run the `parse.py` script on `net.sql` to extract kernel and runtime information and\nsave it as `net.dict`.\n\n```sh\npython -m apex.pyprof.parse net.sql > net.dict\n```\n\nThis will produce a text file, which can be parsed by any external tool, but it can also be directly read one line at a time by Python by calling `eval` on the line being read. 
\n\n**Note: you do not need to process this output manually.**  Here the output is just shown as an example of modularity - you can process the raw data yourself, or let the next step enrich the information further and dump a CSV.\n\nThe output of this step will look as follows. Note that the dictionary has a lot more keys than the ones shown in the example.\n\n```\n>>> with open('torchvision.resnet50.adam.64.dict') as f:\n...     for line in f:\n...         d = eval(line)\n...         print(d['kShortName'], d['op'], d['kDuration'], d['block'], d['grid'], d['device'], d['stream'], d['trace'])\n... \nnchwToNhwc3To4Kernel ['conv2d'] 376324 (256, 1, 1) (1568, 1, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\ngeneric4Channel_kernel ['conv2d'] 10720 (512, 1, 1) (19, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nfirst_layer_fwd_kernel ['conv2d'] 411204 (128, 1, 1) (2, 7, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nnhwcToNchwKernel ['conv2d'] 342371 (256, 1, 1) (392, 2, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nelementwise_kernel ['__iadd__'] 2816 (128, 1, 1) (1, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196']\nbatch_norm_collect_statistics_kernel ['batch_norm', 'batch_norm'] 929513 (512, 1, 1) (64, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196']\n```\n\n3. Run the `prof.py` script on `net.dict` to summarize the results into a CSV file, or to display the pretty-printed results on the screen. 
This step processes the raw output from step 2 to generate a nice output, but it also adds a lot of extra useful information inferred from the previous step, such as:\n- FLOPs\n- bandwidth (bytes in and out of GPU DRAM)\n- tensor core usage\n\n```sh\npython -m apex.pyprof.prof --csv net.dict > results.csv\n```\n\nYou can choose which columns you'd like to display. Here's a list from calling `python -m apex.pyprof.prof -h`:\n\n```\n              idx:      Index\n              seq:      PyTorch Sequence Id\n              altseq:   PyTorch Alternate Sequence Id\n              tid:      Thread Id\n              layer:    User annotated NVTX string (can be nested)\n              trace:    Function Call Trace\n              dir:      Direction\n              sub:      Sub Sequence Id\n              mod:      Module\n              op:       Operation\n              kernel:   Kernel Name\n              params:   Parameters\n              sil:      Silicon Time (in ns)\n              tc:       Tensor Core Usage\n              device:   GPU Device Id\n              stream:   Stream Id\n              grid:     Grid Dimensions\n              block:    Block Dimensions\n              flops:    Floating point ops (FMA = 2 FLOPs)\n              bytes:    Number of bytes in and out of DRAM\n```              \n\nLet's have a look at the pretty-printed output:\n```\npython -m apex.pyprof.prof -w 100 -c kernel,op,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict\n\nKernel              Op              Sil(ns)    TC FLOPs        Bytes        Dev Str Block        Grid         \nelementwise_kernel  relu                381028 -      51380224    205520896   0   7 512,1,1      100352,1,1   \nvolta_fp16_s884cudn conv2d              160002 1    1644167168     51388416   0   7 256,1,1      784,1,1      \nelementwise_kernel  relu                 96545 -      12845056     51380224   0   7 512,1,1      25088,1,1    \nvolta_fp16_s884cudn conv2d              346083 1    
6576668672    128483328   0   7 256,1,1      784,2,1      \n```\n\nNot using the pretty-print width (`-w`) option and adding `--csv` results in a CSV output instead:\n\n```\npython -m apex.pyprof.prof --csv -c kernel,mod,op,dir,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict\n\n\"Kernel\",\"Module\",\"Op\",\"Direction\",\"Sil(ns)\",\"TC\",\"FLOPs\",\"Bytes\",\"Device\",\"Stream\",\"Block\",\"Grid\"\n\"nchwToNhwc3To4Kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"376324\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"1568,1,64\"\n\"generic4Channel_kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"10720\",\"-\",\"0\",\"0\",\"0\",\"7\",\"512,1,1\",\"19,1,1\"\n\"first_layer_fwd_kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"411204\",\"-\",\"0\",\"0\",\"0\",\"7\",\"128,1,1\",\"2,7,64\"\n\"nhwcToNchwKernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"342371\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"392,2,64\"\n\"elementwise_kernel\",\"Tensor\",\"__iadd__\",\"fprop\",\"2816\",\"-\",\"1.0\",\"8\",\"0\",\"7\",\"128,1,1\",\"1,1,1\"\n\"batch_norm_collect_statistics_kernel\",\"torch.nn.functional\",\"batch_norm\",\"fprop\",\"929513\",\"-\",\"411041792\",\"411041792\",\"0\",\"7\",\"512,1,1\",\"64,1,1\"\n\"batch_norm_transform_input_kernel\",\"torch.nn.functional\",\"batch_norm\",\"fprop\",\"377539\",\"-\",\"411041792\",\"411041792\",\"0\",\"7\",\"512,1,1\",\"64,64,1\"\n\"elementwise_kernel\",\"torch.nn.functional\",\"relu\",\"fprop\",\"381028\",\"-\",\"51380224\",\"205520896\",\"0\",\"7\",\"512,1,1\",\"100352,1,1\"\n\"MaxPoolForward\",\"torch.nn.functional\",\"max_pool2d\",\"fprop\",\"406531\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"50176,1,1\"\n\"cudnn::gemm::computeOffsetsKernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"2464\",\"-\",\"0\",\"0\",\"0\",\"7\",\"128,1,1\",\"25,1,1\"\n```\n\n### Hardware Counters\n\nProfiling GPU workloads may require access to [hardware performance 
counters](https://en.wikipedia.org/wiki/Hardware_performance_counter). Due to a [fix](https://nvidia.custhelp.com/app/answers/detail/a_id/4738) in recent NVIDIA drivers addressing [CVE-2018-6260](https://nvd.nist.gov/vuln/detail/CVE-2018-6260), the hardware counters are disabled by default, and require elevated privileges to be enabled again. If you're using a recent driver, you may see the following message when trying to run nvprof:\n\n```**_ERR_NVGPUCTRPERM The user running <tool_name/application_name> does not have permission to access NVIDIA GPU Performance Counters on the target device._**```\n\nFor details, see [here](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters).\n\n_Permanent solution_\n\nFollow the steps [here](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters). The current steps for Linux are:\n```\nsudo systemctl isolate multi-user\nsudo modprobe -r nvidia_uvm nvidia_drm nvidia_modeset nvidia-vgpu-vfio nvidia\nsudo modprobe nvidia NVreg_RestrictProfilingToAdminUsers=0\nsudo systemctl isolate graphical\n```\nThe above steps should result in a permanent change.\n\n_Temporary solution_\n\nWhen running on bare metal, you can run nvprof with `sudo`.\n\nIf you're running in a Docker image, you can temporarily elevate your privileges with one of the following (oldest to newest syntax):\n<pre>\nnvidia-docker run <b>--privileged</b>\ndocker run --runtime nvidia <b>--privileged</b>\ndocker run --gpus all <b>--privileged</b>\n</pre>\n"
  },
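The README notes that each line of the parser's output is a Python dictionary readable via `eval`. Since every value in these records is a literal (strings, numbers, lists, tuples), `ast.literal_eval` is a safer drop-in for reading them. A small sketch; the sample records below are illustrative, modeled on the README's example keys:

```python
import ast
import io

# One record per line, as emitted by `python -m apex.pyprof.parse` (values made up).
sample = (
    "{'kShortName': 'elementwise_kernel', 'op': ['relu'], 'kDuration': 381028, "
    "'block': (512, 1, 1), 'grid': (100352, 1, 1), 'device': 0, 'stream': 7}\n"
    "\n"
    "{'kShortName': 'MaxPoolForward', 'op': ['max_pool2d'], 'kDuration': 406531, "
    "'block': (256, 1, 1), 'grid': (50176, 1, 1), 'device': 0, 'stream': 7}\n"
)

def read_dicts(fobj):
    """Yield one dict per non-empty line of a parse-style output file."""
    for line in fobj:
        line = line.strip()
        if line:
            yield ast.literal_eval(line)  # literals only; rejects arbitrary code

records = list(read_dicts(io.StringIO(sample)))
```

With a real `net.dict`, the same function works on `open('net.dict')` in place of the `StringIO` stand-in.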
  {
    "path": "KoSentenceT5/apex/pyprof/__init__.py",
    "content": "import warnings\n\nfrom . import nvtx, prof\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/.gitignore",
    "content": "__pycache__\n*.sql\n*.dict\n*.csv\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/apex/README.md",
    "content": "This directory has examples of how to use `pyprof` with APEX extensions e.g. `fused_adam_cuda` and `fused_layer_norm_cuda`.\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/apex/fused_adam.py",
    "content": "import torch\nimport fused_adam_cuda\nfrom apex.optimizers import FusedAdam, FP16_Optimizer\nfrom apex import pyprof\n\npyprof.nvtx.init()\npyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\nmodel = torch.nn.Linear(10, 20).cuda().half()\ncriterion = torch.nn.CrossEntropyLoss().cuda()\noptimizer = FusedAdam(model.parameters())\noptimizer = FP16_Optimizer(optimizer)\n\nx = torch.ones(32, 10).cuda().half()\ntarget = torch.empty(32, dtype=torch.long).random_(20).cuda()\ny = model(x)\nloss = criterion(y, target)\noptimizer.zero_grad()\nloss.backward()\noptimizer.step()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/apex/fused_layer_norm.py",
    "content": "import torch\nimport fused_layer_norm_cuda\nfrom apex.normalization import FusedLayerNorm\nfrom apex import pyprof\n\npyprof.nvtx.init()\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward_affine')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward_affine')\n\ninput = torch.randn(20, 5, 10, 10).cuda()\n\n# With Learnable Parameters\nm = FusedLayerNorm(input.size()[1:]).cuda()\noutput = m(input)\n\n# Without Learnable Parameters\nm = FusedLayerNorm(input.size()[1:], elementwise_affine=False).cuda()\noutput = m(input)\n\n# Normalize over last two dimensions\nm = FusedLayerNorm([10, 10]).cuda()\noutput = m(input)\n\n# Normalize over last dimension of size 10\nm = FusedLayerNorm(10).cuda()\noutput = m(input)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/apex/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/custom_func_module/README.md",
    "content": "This directory has examples which show how to intercept (monkey patch) custom functions and modules with `pyprof`. No changes are required in `pyprof/parse`, however, users can add support for bytes and flops calculation for custom functions and modules in `pyprof/prof` by extending the `OperatorLayerBase` class.\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/custom_func_module/custom_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n#Initialize pyprof\npyprof.nvtx.init()\n\nclass Foo(torch.autograd.Function):\n\t@staticmethod\n\tdef forward(ctx, in1, in2):\n\t\tout = in1 + in2\t\t#This could be a custom C/C++ function.\n\t\treturn out\n\n\t@staticmethod\n\tdef backward(ctx, grad):\n\t\tin1_grad = grad\t\t#This could be a custom C/C++ function.\n\t\tin2_grad = grad\t\t#This could be a custom C/C++ function.\n\t\treturn in1_grad, in2_grad\n\n#Hook the forward and backward functions to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\npyprof.nvtx.wrap(Foo, 'backward')\n\nfoo = Foo.apply\n\nx = torch.ones(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x,y)\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/custom_func_module/custom_module.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\npyprof.nvtx.init()\n\nclass Foo(torch.nn.Module):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    def forward(self, input):\n        return self.n*input + self.m\n\n#Hook the forward function to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x)\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/custom_func_module/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/imagenet/imagenet.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nExample to run pyprof with imagenet models.\n\"\"\"\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torchvision.models as models\nimport torch.cuda.profiler as profiler\nimport argparse\n\nfrom apex import pyprof\nfrom apex.optimizers import FusedAdam\n\ndef parseArgs():\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"Run popular imagenet models.\")\n\n\tparser.add_argument(\"-m\",\n\t\ttype=str,\n\t\tdefault=\"resnet50\",\n\t\tchoices=[\"alexnet\", \"densenet121\", \"densenet161\", \"densenet169\", \"densenet201\", \"googlenet\", \"mnasnet0_5\", \"mnasnet0_75\", \"mnasnet1_0\", \"mnasnet1_3\", \"mobilenet_v2\", \"resnet18\", \"resnet34\", \"resnet50\", \"resnet101\", \"resnet152\", \"resnext50_32x4d\", \"resnext101_32x8d\", \"wide_resnet50_2\", \"wide_resnet101_2\", \"shufflenet_v2_x0_5\", \"shufflenet_v2_x1_0\", \"shufflenet_v2_x1_5\", \"shufflenet_v2_x2_0\", \"squeezenet1_0\", \"squeezenet1_1\", \"vgg11\", \"vgg11_bn\", \"vgg13\", \"vgg13_bn\", \"vgg16\", \"vgg16_bn\", \"vgg19\", \"vgg19_bn\", \"inception_v3\"],\n\t\thelp=\"Model.\")\n\n\tparser.add_argument(\"-b\",\n\t\ttype=int,\n\t\tdefault=32,\n\t\thelp=\"Batch size.\")\n\n\tparser.add_argument(\"-o\",\n\t\ttype=str,\n\t\tdefault=\"adam\",\n\t\tchoices=[\"adam\", \"sgd\"],\n\t\thelp=\"Optimizer.\")\n\n\targs = parser.parse_args()\n\treturn args\n\nd = {\n\t\"alexnet\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"densenet121\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet161\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet169\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet201\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"googlenet\":\t\t\t{'H': 224, 'W': 224, 'opts': {'aux_logits': False}},\n\n\t\"mnasnet0_5\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet0_75\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet1_0\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet1_3\":\t\t\t{'H': 224, 
'W': 224, 'opts': {}},\n\n\t\"mobilenet_v2\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"resnet18\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet34\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet50\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet101\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet152\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"resnext50_32x4d\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnext101_32x8d\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"wide_resnet50_2\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"wide_resnet101_2\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"shufflenet_v2_x0_5\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x1_0\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x1_5\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x2_0\":\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"squeezenet1_0\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"squeezenet1_1\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"vgg11\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg11_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg13\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg13_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg16\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg16_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg19\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg19_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"inception_v3\":\t\t\t{'H': 299, 'W': 299, 'opts': {'aux_logits': False}},\n\t}\n\ndef main():\n\targs = parseArgs()\n\n\tpyprof.nvtx.init()\n#\tpyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\n\tN = args.b\n\tC = 3\n\tH = d[args.m]['H']\n\tW = d[args.m]['W']\n\topts = d[args.m]['opts']\n\tclasses = 1000\n\n\tnet = getattr(models, args.m)\n\tnet = net(**opts).cuda().half()\n\tnet.train()\n\n\tx = torch.rand(N, C, H, W).cuda().half()\n\ttarget = torch.empty(N, dtype=torch.long).random_(classes).cuda()\n\n\tcriterion = 
nn.CrossEntropyLoss().cuda()\n\tif (args.o == \"sgd\"):\n\t\toptimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n\telif (args.o == \"adam\"):\n\t\toptimizer = FusedAdam(net.parameters())\n\telse:\n\t\tassert False\n\n\t#Warm up without profiler\n\tfor i in range(2):\n\t\toutput = net(x)\n\t\tloss = criterion(output, target)\n\t\toptimizer.zero_grad()\n\t\tloss.backward()\n\t\toptimizer.step()\n\n\twith torch.autograd.profiler.emit_nvtx():\n\t\tprofiler.start()\n\t\toutput = net(x)\n\t\tloss = criterion(output, target)\n\t\toptimizer.zero_grad()\n\t\tloss.backward()\n\t\toptimizer.step()\n\t\tprofiler.stop()\n\nif __name__ == \"__main__\":\n\tmain()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/imagenet/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python -m apex.pyprof.parse\"\nprof=\"python -m apex.pyprof.prof\"\n\nfor net in \"resnet50\"\ndo\n\tfor optim in adam sgd\n\tdo\n\t\tfor batch in 32 64\n\t\tdo\n\t\t\tbase=\"torchvision\".$net.$optim.$batch\n\t\t\tsql=$base.sql\n\t\t\tdict=$base.dict\n\n\t\t\t#NVprof\n\t\t\techo \"nvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch\"\n\t\t\tnvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch\n\n\t\t\t#Parse\n\t\t\techo $parse $sql\n\t\t\t$parse $sql > $dict\n\n\t\t\t#Prof\n\t\t\techo $prof $dict\n\t\t\t$prof -w 130 $dict\n#\t\t\t\\rm $sql $dict\n\t\tdone\n\tdone\ndone\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/README.md",
    "content": "*As of this writing, these examples do not work\nbecause of changes being proposed in PyTorch.*\n\nThere are two ways to use PyTorch JIT\n - Scripting\n - Tracing\n\nIn addition, we can JIT a\n - Stand alone function\n - Class / class method\n\nThis directory has an example for each of the 4 cases.\nIntercepting (monkey patching) JITted code has a few extra steps,\nwhich are explained through comments.\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/jit_script_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\n#The following creates an object \"foo\" of type ScriptModule\n#The new object has a function called \"forward\"\n\n@torch.jit.script\ndef foo(x, y):\n\treturn torch.sigmoid(x) + y\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"foo\"\nfoo.__name__ = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(foo, 'forward')\n\nx = torch.zeros(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x, y)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/jit_script_method.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\nclass Foo(torch.jit.ScriptModule):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    @torch.jit.script_method\n    def forward(self, input):\n        return self.n*input + self.m\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/jit_trace_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\ndef foo(x, y):\n\treturn torch.sigmoid(x) + y\n\nx = torch.zeros(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\n#JIT the function using tracing\n#This returns an object of type ScriptModule with a forward method.\ntraced_foo = torch.jit.trace(foo, (x,y))\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"traced_foo\"\ntraced_foo.__dict__['__name__'] = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(traced_foo, 'forward')\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = traced_foo(x, y)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/jit_trace_method.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\nclass Foo(torch.nn.Module):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    def forward(self, input):\n        return self.n*input + self.m\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\n#JIT the class using tracing\ntraced_foo = torch.jit.trace(foo, x)\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"traced_foo\"\ntraced_foo.__dict__['__name__'] = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(traced_foo, 'forward')\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = traced_foo(x)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/jit/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/lenet.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.cuda.profiler as profiler\nimport torch.optim as optim\n\nfrom apex import pyprof\npyprof.nvtx.init()\n\nclass LeNet5(nn.Module):\n\tdef __init__(self):\n\t\tsuper(LeNet5, self).__init__()\n\t\t# 1 input image channel, 6 output channels, 5x5 square convolution\n\t\t# kernel\n\t\tself.conv1 = nn.Conv2d(1, 6, 5)\n\t\tself.conv2 = nn.Conv2d(6, 16, 5)\n\t\t# an affine operation: y = Wx + b\n\t\tself.fc1 = nn.Linear(16 * 5 * 5, 120)\n\t\tself.fc2 = nn.Linear(120, 84)\n\t\tself.fc3 = nn.Linear(84, 10)\n\n\tdef forward(self, x):\n\t\t# Max pooling over a (2, 2) window\n\t\tx = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n\t\t# If the size is a square you can only specify a single number\n\t\tx = F.max_pool2d(F.relu(self.conv2(x)), 2)\n\t\tx = x.view(-1, self.num_flat_features(x))\n\t\tx = F.relu(self.fc1(x))\n\t\tx = F.relu(self.fc2(x))\n\t\tx = self.fc3(x)\n\t\treturn x\n\n\tdef num_flat_features(self, x):\n\t\tsize = x.size()[1:]  # all dimensions except the batch dimension\n\t\tnum_features = 1\n\t\tfor s in size:\n\t\t\tnum_features *= s\n\t\treturn num_features\n\nwith torch.autograd.profiler.emit_nvtx():\n\n\tnet = LeNet5().cuda()\n\n\tinput = torch.randn(1, 1, 32, 32).cuda()\n\tout = net(input)\n\n\ttarget = torch.randn(10)\t\t\t# a dummy target, for example\n\ttarget = target.view(1, -1).cuda()\t# make it the same shape as output\n\tcriterion = nn.MSELoss()\n\n\t# create your optimizer\n\toptimizer = optim.SGD(net.parameters(), lr=0.01)\n\n\t# in your training loop:\n\toptimizer.zero_grad()\t# zero the gradient buffers\n\n\tprofiler.start()\n\toutput = net(input)\n\tloss = criterion(output, target)\n\tloss.backward()\n\toptimizer.step()\t# Does the update\n\tprofiler.stop()\n\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/operators.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis file checks all Python operators.\n\"\"\"\n\nimport sys\nimport torch\nimport torch.cuda.profiler as profiler\nimport operator\nimport inspect\n\n#Import and initialize pyprof\nfrom apex import pyprof\npyprof.nvtx.init()\n\nX = 1024\nY = 1024\n\nfa = torch.rand(X, Y).cuda()\nfb = torch.rand(X, Y).cuda()\nfc = torch.rand(X, Y).cuda()\n\nia = torch.randint(0, 100, (X, Y)).cuda()\nib = torch.randint(0, 100, (X, Y)).cuda()\n\nsa = torch.ones(1,1).cuda()\nsb = torch.ones(1,1).cuda()\n\nba = fa.byte()\n\nunaryOps = [\"abs\", \"__abs__\", \"neg\", \"__neg__\",]\ninvertOps = [\"inv\", \"invert\", \"__inv__\", \"__invert__\",]\t#imlemented only for byte tensors\n#pos, __pos__ is not implemented for tensors\n\nbinaryOps = []\nbinaryOps += [ \"lt\", \"__lt__\", \"le\", \"__le__\", \"eq\", \"__eq__\", \"ne\", \"__ne__\", \"ge\", \"__ge__\", \"gt\", \"__gt__\" ]\nbinaryOps += [ \"add\", \"__add__\", \"sub\", \"__sub__\", \"mul\", \"__mul__\", \"floordiv\", \"__floordiv__\", \"truediv\", \"__truediv__\", \"pow\", \"__pow__\", \"mod\", \"__mod__\"]\nbinaryOps += [ \"and_\", \"__and__\", \"or_\", \"__or__\", \"xor\", \"__xor__\", \"lshift\", \"__lshift__\", \"rshift\", \"__rshift__\"]\n\ninplaceOps = []\ninplaceOps += [\"iadd\", \"__iadd__\", \"isub\", \"__isub__\", \"imul\", \"__imul__\", \"ifloordiv\", \"__ifloordiv__\", \"itruediv\", \"__itruediv__\", \"imod\", \"__imod__\",]\n#ipow, __ipow__ is not implemented in pytorch\ninplaceOps += [ \"iand\", \"__iand__\", \"ior\", \"__ior__\", \"ixor\", \"__ixor__\", \"ilshift\", \"__ilshift__\", \"irshift\", \"__irshift__\",]\n\nmatmulOps = [ \"matmul\", \"__matmul__\" ]\ninplacematmulOps = [ \"imatmul\", \"__imatmul__\" ]\n\nreverseIntBinaryOps = [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rfloordiv__\", \"__rpow__\",]\nreverseFloatBinaryOps = [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rdiv__\", \"__rtruediv__\", \"__rfloordiv__\", \"__rpow__\",]\n\n'''\nTODO\n.concat(a, 
b)\n.__concat__(a, b)\n.contains(a, b)\n.__contains__(a, b)\n.countOf(a, b)\n.delitem(a, b)\n.__delitem__(a, b)\n.getitem(a, b)\n.__getitem__(a, b)\n.indexOf(a, b)\n.setitem(a, b, c)\n.__setitem__(a, b, c)\n.length_hint(obj, default=0)\n.iconcat(a, b)\n.__iconcat__(a, b)\n.index(a)\n.__index__(a)\n'''\n\n#Context manager\nwith torch.autograd.profiler.emit_nvtx():\n\n\t#Start profiler\n\tprofiler.start()\n\n\tfor op in unaryOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ia)\n\n\tfor op in invertOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ba)\n\n\tfor op in binaryOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ia, ib)\n\t\tc = f(ia, 2)\n\n\tfor op in inplaceOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tia = f(ia, ib)\n\t\tia = f(ia, 2)\n\n\tfor op in matmulOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(fa, fb)\n\n\tfor op in inplacematmulOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tfa = f(fa, fb)\n\n\tfor op in reverseIntBinaryOps:\n\t\tassert hasattr(torch.Tensor, op)\n\t\tf = getattr(torch.Tensor, op)\n\t\tia = f(ia, ib)\n\n\tfor op in reverseFloatBinaryOps:\n\t\tassert hasattr(torch.Tensor, op)\n\t\tf = getattr(torch.Tensor, op)\n\t\tfa = f(fa, fb)\n\n\t'''\n\t#c = fa[3]\n\t#c = fa[3][3]\n\t#c = torch.min(fa, 3)\n\tc = torch.sum(fa)\n\tc = torch.max(fa)\n\tc = -fa\n\t#fc[2][2] = fa[2][2]\n\n\tc = a_scalar and b_scalar\n\tc = a_scalar or b_scalar\n\tc = not a_scalar\n\n\tc = a is b\n\tc = a is not b\n\t'''\n\n\t#Stop profiler\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/simple.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis simple file provides an example of how to\n - import the pyprof library and initialize it\n - use the emit_nvtx context manager\n - start and stop the profiler\n\nOnly kernels within profiler.start and profiler.stop calls are profiled.\nTo profile\n$ nvprof -f -o simple.sql --profile-from-start off ./simple.py\n\"\"\"\n\nimport sys\nimport torch\nimport torch.cuda.profiler as profiler\n\n#Import and initialize pyprof\nfrom apex import pyprof\npyprof.nvtx.init()\n\na = torch.randn(5, 5).cuda()\nb = torch.randn(5, 5).cuda()\n\n#Context manager\nwith torch.autograd.profiler.emit_nvtx():\n\n\t#Start profiler\n\tprofiler.start()\n\n\tc = a + b\n\tc = torch.mul(a,b)\n\tc = torch.matmul(a,b)\n\tc = torch.argmax(a, dim=1)\n\tc = torch.nn.functional.pad(a, (1,1))\n\n\t#Stop profiler\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/user_annotation/README.md",
    "content": "Nvidia NVTX range markers (https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) \nare a useful tool to capture and observe events and code ranges etc. \nUsing PyTorch APIs e.g, `torch.cuda.nvtx.range_push(\"xxx\")` and `torch.cuda.nvtx.range_pop()` users can easily add their own NVTX range markers. These markers can then be observed in the Nvidia Visual Profiler (NVVP).\n\nWhile inserting NVTX markers (strings), if the users follow a specific string pattern `\"layer:your_string_here\"` e.g. `\"layer:conv1\"` or `\"layer:encoder_layer_3_self_attention`, then `pyprof` will display the strings `conv1` and `encoder_layer_3_self_attention` next to the associated kernels in the output of `prof.py` when used with the `-c layer` option.\n\nNVTX range markers can be nested and if users follow the above string pattern, the output of `prof.py` will show all the markers associated with a kernel.\n\nThe file `resnet.py` (a simplified version of the torchvision model) shows an example of how users can add (nested) NVTX markers with information which can greatly aid in understanding and analysis of networks.\n\nNote that the pattern `\"layer:your_string_here\"` was chosen to aid information extraction by `pyprof`. The tool will work seamlessly even if there are other markers or no markers at all.\n\n### To run\n\n```sh\nnvprof -fo resnet.sql --profile-from-start off python resnet.py\nparse.py resnet.sql > resnet.dict\nprof.py --csv -c idx,layer,dir,mod,op,kernel,params,sil resnet.dict\n```\n\nThe file `resnet.sql` can also be opened with NVVP as usual.\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/user_annotation/resnet.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nAn example showing use of nested NVTX markers.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nimport torch.cuda.profiler as profiler\nimport torch.cuda.nvtx as nvtx\nfrom apex import pyprof\npyprof.nvtx.init()\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n\t\"\"\"3x3 convolution with padding\"\"\"\n\treturn nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n\t\t\t\t\t padding=dilation, groups=groups, bias=False, dilation=dilation)\n\ndef conv1x1(in_planes, out_planes, stride=1):\n\t\"\"\"1x1 convolution\"\"\"\n\treturn nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\nclass Bottleneck(nn.Module):\n\texpansion = 4\n\tcount = 1\n\n\tdef __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n\t\t\t\t base_width=64, dilation=1, norm_layer=None):\n\t\tsuper(Bottleneck, self).__init__()\n\t\tif norm_layer is None:\n\t\t\tnorm_layer = nn.BatchNorm2d\n\t\twidth = int(planes * (base_width / 64.)) * groups\n\t\t# Both self.conv2 and self.downsample layers downsample the input when stride != 1\n\t\tself.conv1 = conv1x1(inplanes, width)\n\t\tself.bn1 = norm_layer(width)\n\t\tself.conv2 = conv3x3(width, width, stride, groups, dilation)\n\t\tself.bn2 = norm_layer(width)\n\t\tself.conv3 = conv1x1(width, planes * self.expansion)\n\t\tself.bn3 = norm_layer(planes * self.expansion)\n\t\tself.relu = nn.ReLU(inplace=True)\n\t\tself.downsample = downsample\n\t\tself.stride = stride\n\n\t\tself.id = Bottleneck.count\n\t\tBottleneck.count += 1\n\n\tdef forward(self, x):\n\t\tidentity = x\n\n\t\tnvtx.range_push(\"layer:Bottleneck_{}\".format(self.id))\n\n\t\tnvtx.range_push(\"layer:Conv1\")\n\t\tout = self.conv1(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN1\")\n\t\tout = self.bn1(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Conv2\")\n\t\tout = 
self.conv2(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN2\")\n\t\tout = self.bn2(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Conv3\")\n\t\tout = self.conv3(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN3\")\n\t\tout = self.bn3(out)\n\t\tnvtx.range_pop()\n\n\t\tif self.downsample is not None:\n\t\t\tnvtx.range_push(\"layer:Downsample\")\n\t\t\tidentity = self.downsample(x)\n\t\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Residual\")\n\t\tout += identity\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_pop()\n\n\t\treturn out\n\nclass ResNet(nn.Module):\n\n\tdef __init__(self, block, layers, num_classes=1000,\n\t\t\t\t groups=1, width_per_group=64, norm_layer=None):\n\t\tsuper(ResNet, self).__init__()\n\t\tif norm_layer is None:\n\t\t\tnorm_layer = nn.BatchNorm2d\n\t\tself._norm_layer = norm_layer\n\n\t\tself.inplanes = 64\n\t\tself.dilation = 1\n\n\t\tself.groups = groups\n\t\tself.base_width = width_per_group\n\t\tself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n\t\tself.bn1 = norm_layer(self.inplanes)\n\t\tself.relu = nn.ReLU(inplace=True)\n\t\tself.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\t\tself.layer1 = self._make_layer(block, 64, layers[0])\n\t\tself.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n\t\tself.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n\t\tself.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n\t\tself.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\t\tself.fc = nn.Linear(512 * block.expansion, num_classes)\n\n\t\tfor m in self.modules():\n\t\t\tif isinstance(m, nn.Conv2d):\n\t\t\t\tnn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n\t\t\telif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n\t\t\t\tnn.init.constant_(m.weight, 
1)\n\t\t\t\tnn.init.constant_(m.bias, 0)\n\n\tdef _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n\t\tnorm_layer = self._norm_layer\n\t\tdownsample = None\n\t\tprevious_dilation = self.dilation\n\t\tif dilate:\n\t\t\tself.dilation *= stride\n\t\t\tstride = 1\n\t\tif stride != 1 or self.inplanes != planes * block.expansion:\n\t\t\tdownsample = nn.Sequential(\n\t\t\t\tconv1x1(self.inplanes, planes * block.expansion, stride),\n\t\t\t\tnorm_layer(planes * block.expansion),\n\t\t\t)\n\n\t\tlayers = []\n\t\tlayers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n\t\t\t\t\t\t\tself.base_width, previous_dilation, norm_layer))\n\t\tself.inplanes = planes * block.expansion\n\t\tfor _ in range(1, blocks):\n\t\t\tlayers.append(block(self.inplanes, planes, groups=self.groups,\n\t\t\t\t\t\t\t\tbase_width=self.base_width, dilation=self.dilation,\n\t\t\t\t\t\t\t\tnorm_layer=norm_layer))\n\n\t\treturn nn.Sequential(*layers)\n\n\tdef forward(self, x):\n\n\t\tnvtx.range_push(\"layer:conv1_x\")\n\t\tx = self.conv1(x)\n\t\tx = self.bn1(x)\n\t\tx = self.relu(x)\n\t\tx = self.maxpool(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv2_x\")\n\t\tx = self.layer1(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv3_x\")\n\t\tx = self.layer2(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv4_x\")\n\t\tx = self.layer3(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv5_x\")\n\t\tx = self.layer4(x)\n\t\tnvtx.range_pop()\n\n\t\tx = self.avgpool(x)\n\t\tx = torch.flatten(x, 1)\n\n\t\tnvtx.range_push(\"layer:FC\")\n\t\tx = self.fc(x)\n\t\tnvtx.range_pop()\n\n\t\treturn x\n\n\ndef resnet50():\n\treturn ResNet(Bottleneck, [3, 4, 6, 3])\n\n#Create model\nnet = resnet50().cuda().half()\nnet.train()\n\n#Create optimizer\ncriterion = nn.CrossEntropyLoss().cuda()\noptimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n\n#Create synthetic input and label\nx = torch.rand(32, 3, 224, 224).cuda().half()\ntarget 
= torch.empty(32, dtype=torch.long).random_(1000).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\toutput = net(x)\n\tloss = criterion(output, target)\n\toptimizer.zero_grad()\n\tloss.backward()\n\toptimizer.step()\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/examples/user_annotation/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo --profile-from-start off $sql python $f\"\n\tnvprof -fo $sql --profile-from-start off python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t#$prof -w 130 $dict\n\t$prof --csv -c idx,layer,dir,mod,op,kernel,params,sil $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/nvtx/__init__.py",
    "content": "from .nvmarker import init\nfrom .nvmarker import add_wrapper as wrap\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/nvtx/nvmarker.py",
    "content": "\"\"\"\nThis file intercepts (monkey patches) the following functions and adds NVTX markers.\n\ttorch.*\n\ttorch.Tensor.*\n\ttorch.nn.functional.*\n\ttorch.nn.*.forward\n\nThe NVTX markers (one or more) contain the following information\n\tcall trace (a list of file_name:line_number)\n\textra_repr() from torch.nn modules\n\tmodule/class name\n\tfunction name\n\tinputs (args and kwargs)\n\t\tscalar: name, type and value\n\t\ttensor: name, shape and datatype\n\t\tnumpy: name, shape and datatype\n\t\tlist/tuple: a sequence of scalars or tensors or numpy arrays\n\"\"\"\n\nimport torch\nimport torch.cuda.nvtx as nvtx\nimport numpy\nimport inspect as ins\nimport traceback\nimport math\n\ndef isfunc(mod, f):\n\tassert hasattr(mod, f)\n\tattr = getattr(mod, f)\n\n\t#Ignore functions like _add\n\tif (len(f) >= 2):\n\t\tif f[0] == \"_\" and f[1] != \"_\":\n\t\t\treturn False\n\n\t#Ignore functions from this list\n\tignore = ['__all__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__builtins__', '__cached__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__file__', '__format__', '__getattribute__', '__getitem__', '__hash__', '__index__', '__init__', '__init_subclass__', '__iter__', '__len__', '__loader__', '__module__', '__name__', '__new__', '__nonzero__', '__package__', '__path__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__spec__', '__str__', '__subclasshook__', '__version__', '__weakref__']\n\n\t#Add functions to this list if they cause recursion\n\tignore += ['size', 'tolist', 'dim', 'is_storage', 'item']\n\tif f in ignore:\n\t\treturn False\n\n\treturn ins.ismethod(attr) or ins.isfunction(attr) or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr)\n\ndef traceMarker(stack):\n\td = {}\n\tcadena = []\n\tfor i in range(len(stack)-1):\n\t\tfi = stack[i]\n\t\tt = \"{}:{}\".format(fi.filename, 
fi.lineno)\n\t\tcadena.append(t)\n\td['traceMarker'] = cadena\n\treturn str(d)\n\ndef modMarker(mod, fn_name, args):\n\t\"\"\"\n\tReturns the stringified extra_repr() of a module.\n\t\"\"\"\n\tassert(fn_name == 'forward')\n\tassert(len(args) > 0)\n\td = {}\n\td['mod'] = mod.__name__\n\td['strRepr'] = args[0].extra_repr()\n\treturn str(d)\n\ndef add_wrapper(mod, fn_name):\n\tassert isfunc(mod, fn_name)\n\n\t# Get a pointer to the original function\n\tfunc = getattr(mod, fn_name)\n\n\t# Check if the mod has a string representation\n\t# and is not a Script or Traced module (used by JIT)\n\ts = hasattr(mod, \"extra_repr\") and (type(mod) is not torch.jit.ScriptModule) and (type(mod) is not torch.jit.TopLevelTracedModule)\n\n\tdef wrapper_func(*args, **kwargs):\n\n\t\t# Extract the stacktrace\n\t\tstack = traceback.extract_stack()\n\n\t\t# Push trace marker\n\t\tnvtx.range_push(traceMarker(stack))\n\n\t\t# Push module marker\n\t\tif s:\n\t\t\tm = modMarker(mod, fn_name, args)\n\t\t\tnvtx.range_push(m)\n\n\t\t# Create and push argument marker\n\t\tcadena = argMarker(mod, fn_name, args, kwargs)\n\t\tnvtx.range_push(cadena)\n\n\t\t# Call the original function\n\t\tresult = func(*args, **kwargs)\n\n\t\t# Pop argument marker\n\t\tnvtx.range_pop()\n\n\t\t# Pop module marker\n\t\tif s:\n\t\t\tnvtx.range_pop()\n\n\t\t# Pop trace marker\n\t\tnvtx.range_pop()\n\n\t\treturn result\n\tsetattr(mod, fn_name, wrapper_func)\n\ndef argMarker(mod, op, args, kwargs):\n\t#For this function args is a tuple and kwargs is a dict\n\n\tdef tensor(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = \"tensor\"\n\t\ta['shape'] = tuple(arg.size())\n\t\ta['dtype'] = str(arg.dtype).split(\".\")[-1]\n\t\tcadena['args'].append(a)\n\n\tdef ndarray(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = \"ndarray\"\n\t\ta['shape'] = arg.shape\n\t\ta['dtype'] = str(arg.dtype).split(\".\")[-1]\n\t\tcadena['args'].append(a)\n\n\tdef seq(arg, name=\"\"):\n\t\tassert 
issequence(arg)\n\t\ta = {}\n\t\ta['name'] = name\n\t\tif isinstance(arg, list):\n\t\t\ta['type'] = \"list\"\n\t\t\ta['value'] = arg\n\t\telse:\n\t\t\ta['type'] = \"tuple\"\n\t\t\t# The arg could be torch.Size, which is a subclass of tuple\n\t\t\t# Therefore, explicitly convert to tuple\n\t\t\ta['value'] = tuple(arg)\n\t\t\n\t\tcadena['args'].append(a)\n\n\tdef scalar(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = type(arg).__name__\n\t\t#handle the case when the argument is +/- inf or nan\n\t\tif arg == float('inf'):\n\t\t\ta['value'] = \"inf\"\n\t\telif arg == float('-inf'):\n\t\t\ta['value'] = \"-inf\"\n\t\telif isinstance(arg, float) and math.isnan(arg):\n\t\t\ta['value'] = \"nan\"\n\t\telse:\n\t\t\ta['value'] = arg\n\t\tcadena['args'].append(a)\n\n\tdef isscalar(arg):\n\t\treturn (type(arg) is int) or (type(arg) is float) or (type(arg) is bool) or (arg is None) or (type(arg) is str)\n\n\tdef issequence(arg):\n\t\treturn isinstance(arg, list) or isinstance(arg, tuple)\n\n\tdef foo(args, name):\n\t\t#args should be an iterable sequence e.g. 
list or tuple\n\t\tfor arg in args:\n\t\t\tif isinstance(arg, torch.Tensor):\n\t\t\t\tif arg.dim() == 0:\n\t\t\t\t\tscalar(arg.item(), name)\n\t\t\t\telse:\n\t\t\t\t\ttensor(arg, name)\n\t\t\telif isinstance(arg, numpy.ndarray):\n\t\t\t\tndarray(arg, name)\n\t\t\telif (isscalar(arg)):\n\t\t\t\tscalar(arg, name)\n\t\t\telif issequence(arg):\n\t\t\t\tif (len(arg) == 0) or isscalar(arg[0]):\t#An empty sequence or a sequence of scalars\n\t\t\t\t\tseq(arg, name)\n\t\t\t\telse:\t# A sequence of tensors or numpy arrays\n\t\t\t\t\tfoo(arg, name)\n\t\t\t'''\n\t\t\telse:\n\t\t\t\tprint(\"The following arg is none of Tensor, numpy array, scalar but a %s\" % (str(type(arg))))\n\t\t\t\tprint(\"Mod: %s\" % str(mod.__name__))\n\t\t\t\tprint(\"Op: %s\" % str(op))\n\t\t\t\tprint(dir(arg))\n\t\t\t'''\n\n\tcadena = {}\n\tcadena['mod'] = mod.__name__\n\tcadena['op'] = op\n\tcadena['args'] = []\n\n\tfoo(args, \"\")\n\tfor k,v in kwargs.items():\n\t\tfoo((v,), k)\n\n\treturn str(cadena)\n\ndef patchClass(cls):\n\tfor f in dir(cls):\n\t\tif isfunc(cls, f):\n\t\t\tadd_wrapper(cls, f)\n\ndef init():\n\tstring = \"\\n\\nPyprof has been moved to its own dedicated repository and will \" + \\\n\t\t\t\"soon be removed from Apex.  Please visit\\n\" + \\\n\t\t\t\"https://github.com/NVIDIA/PyProf\\n\" + \\\n\t\t\t\"for the latest version.\\n\\n\"\n\t# print regardless of warning state\n\tprint(string)\n\n\tprint(\"Initializing NVTX monkey patches\")\n\tfor cls in [torch, torch.Tensor, torch.nn.functional,]:\n\t\tpatchClass(cls)\n\n\tfor cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]:\n\t\tif isfunc(cls, 'forward'):\n\t\t\tadd_wrapper(cls, 'forward')\n\n\tprint(\"Done with NVTX monkey patching\")\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/__init__.py",
    "content": ""
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/__main__.py",
    "content": "import warnings\n\ntry:\n    from .parse import main\nexcept ImportError as e:\n    warnings.warn(\"Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?)\")\n    raise e\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/db.py",
    "content": "import sys, sqlite3\n\nclass DB(object):\n\t\"\"\"\n\tThis class provides functions for DB operations\n\twith exception handling.\n\t\"\"\"\n\n\tdef __init__(self, dbFile):\n\t\ttry:\n\t\t\tconn = sqlite3.connect(dbFile)\n\t\t\tconn.row_factory = sqlite3.Row\n\t\t\tc = conn.cursor()\n\t\texcept:\n\t\t\tprint(\"Error opening {}\".format(dbFile))\n\t\t\tsys.exit(1)\n\n\t\tself.conn = conn\n\t\tself.c = c\n\n\tdef select(self, cmd):\n\t\ttry:\n\t\t\tself.c.execute(cmd)\n\t\t\t#rows = self.c.fetchall()\n\t\t\trows = [dict(row) for row in self.c.fetchall()]\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\t\t#print(rows)\n\t\treturn rows\n\n\tdef insert(self, cmd, data):\n\t\ttry:\n\t\t\tself.c.execute(cmd, data)\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\tdef execute(self, cmd):\n\t\ttry:\n\t\t\tself.c.execute(cmd)\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\tdef commit(self):\n\t\tself.conn.commit()\n\n\tdef close(self):\n\t\tself.c.close()\n\t\tself.conn.close()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/kernel.py",
    "content": "import cxxfilt, struct, binascii\n\n#Helper functions\n\ndef demangle(name):\n\t\"\"\"\n\tDemangle a C++ string\n\t\"\"\"\n\treturn cxxfilt.demangle(name)\n\ndef encode_object_id(pid, tid):\n\t\"\"\"\n\tGiven process id (pid) and thread id (tid), return the object id.\n\tobject id = pid (little endian 4 bytes) + tid (little endian 8 bytes)\n\t\"\"\"\n\tobjId = struct.pack('<i', pid) + struct.pack('<q',tid)\n\tobjId = binascii.hexlify(objId).decode('ascii').upper()\n\treturn objId\n\ndef getShortName(name):\n\t\"\"\"\n\tReturns a shorter kernel name\n\t\"\"\"\n\tsname = name.split(\"<\")[0] \\\n\t\t\t\t.replace(\"void \", \"\") \\\n\t\t\t\t.replace(\"at::\",\"\") \\\n\t\t\t\t.replace(\"cuda::\", \"\") \\\n\t\t\t\t.replace(\"native::\",\"\") \\\n\t\t\t\t.replace(\"(anonymous namespace)::\", \"\")\n\tsname = sname.split(\"(\")[0]\n\treturn sname\n\nclass Kernel(object):\n\t\"\"\"\n\tThis class stores information about a kernel.\n\t\"\"\"\n\n\tkernels = []\n\tprofStart = 0\n\n\tdef __init__(self):\n\t\tself.kNameId = None\n\t\tself.kShortName = None\n\t\tself.kLongName = None\n\t\tself.kStartTime = None\t#GPU start time\n\t\tself.kEndTime = None\t#GPU end time\n\t\tself.kDuration = None\n\t\tself.device = None\n\t\tself.stream = None\n\t\tself.grid = ()\n\t\tself.block = ()\n\t\tself.corrId = None\n\t\tself.rStartTime = None\t#CPU start time\n\t\tself.rEndTime = None\t#CPU end time\n\t\tself.rDuration = None\n\t\tself.tid = None\n\t\tself.pid = None\n\t\tself.objId = None\n\t\tself.timeOffset = None\n\n\t\tself.layerMarkers = []\n\t\tself.traceMarkers = []\n\t\tself.reprMarkers = []\n\t\tself.pyprofMarkers = []\n\t\tself.seqMarkers = []\n\t\tself.otherMarkers = []\n\t\tself.altMarkers = []\n\t\tself.seqId = []\n\t\tself.altSeqId = []\n\t\tself.layer = []\n\n\t\tself.subSeqId = None\n\t\tself.dir = None\n\t\tself.mod = []\n\t\tself.op = []\n\n\tdef setKernelInfo(self, info):\n\t\tself.kNameId = info['name']\n\t\tself.corrId = 
int(info['correlationId'])\n\t\tstart = int(info['start'])\n\t\tend = int(info['end'])\n\t\tassert end > start, \"This assertion can fail for very large profiles. It usually fails when start = end = 0.\"\n\t\tself.kStartTime = start\n\t\tself.kEndTime = end\n\t\tself.kDuration = end - start\n\t\tassert (start > Kernel.profStart)\n\t\tself.device = int(info['deviceId'])\n\t\tself.stream = int(info['streamId'])\n\t\tself.grid = (info['gridX'], info['gridY'], info['gridZ'])\n\t\tself.block = (info['blockX'], info['blockY'], info['blockZ'])\n\t\tself.timeOffset = Kernel.profStart\n\n\tdef setKernelName(self, name):\n\t\tcadena = demangle(name)\n\t\tself.kLongName = cadena\n\t\tself.kShortName = getShortName(cadena)\n\n\tdef setRunTimeInfo(self, info):\n\t\tstart, end, pid, tid = info\n\t\tself.rStartTime = start\n\t\tself.rEndTime = end\n\t\tself.rDuration = end - start\n\t\tself.pid = pid\n\t\tself.tid = tid\n\t\tself.objId = encode_object_id(pid, tid)\n\n\tdef setMarkerInfo(self, info):\n\t\tself.layerMarkers, self.traceMarkers, self.reprMarkers, self.pyprofMarkers, self.seqMarkers, self.otherMarkers, self.altMarkers, self.seqId, self.altSeqId, self.layer = info\n\t\tself.subSeqId = 0\n\n\tdef setDirection(self):\n\t\t\"\"\"\n\t\tSet direction (fprop, bprop) based on PyTorch sequence markers.\n\t\tIt is a heuristic and not a foolproof method.\n\t\t\"\"\"\n\t\tif\tany(\"Backward, seq = \" in x for x in self.seqMarkers) or \\\n\t\t\tany(\"backward, seq = \" in x for x in self.seqMarkers) or \\\n\t\t\tany(\"Backward0, seq = \" in x for x in self.seqMarkers):\n\t\t\tself.dir = \"bprop\"\n\t\telse:\n\t\t\tself.dir = \"fprop\"\n\n\tdef setOp(self):\n\t\t\"\"\"\n\t\tDetect and set the class/module (mod) and operation (op)\n\t\tof the kernel e.g. 
torch.nn.functional / linear, torch / sigmoid.\n\t\tThe lookup sequence we use is\n\t\t\tNVTX markers inserted by pyprof\n\t\t\tNVTX markers inserted by PyTorch in bprop\n\t\t\tNVTX markers inserted by PyTorch in fprop\n\t\tIt is a heuristic and not a foolproof method.\n\t\t\"\"\"\n\n\t\tdef sanitize(name):\n\t\t\tname = name.replace(\"torch\",\"\") \\\n\t\t\t\t\t\t.replace(\"autograd\",\"\") \\\n\t\t\t\t\t\t.replace(\"_backward\",\"\") \\\n\t\t\t\t\t\t.replace(\"::\",\"\") \\\n\t\t\t\t\t\t.replace(\"jit\",\"\") \\\n\t\t\t\t\t\t.replace(\"(anonymous namespace)\",\"\")\n\t\t\thead, sep, tail = name.partition(\"Backward\")\n\t\t\treturn head\n\n\t\t#Check pyprof markers\n\t\tfor m in self.pyprofMarkers:\n\t\t\tassert (\"mod\" in m) and (\"op\" in m) and (\"args\" in m)\n\t\t\tt = eval(m)\n\t\t\tself.op.append(t['op'])\n\t\t\tself.mod.append(t['mod'])\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#Check bprop kernel markers\n\t\tfor m in self.seqMarkers:\n\t\t\tif (\"backward, seq = \" in m) or (\"Backward, seq = \" in m):\n\t\t\t\top = m.split(\",\")[0]\n\t\t\t\top = sanitize(op)\n\t\t\t\tself.op.append(op)\n\t\t\t\tself.mod.append('na')\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#Check markers with \"seq = \"\n\t\tfor m in self.seqMarkers:\n\t\t\tif \", seq = \" in m:\n\t\t\t\top = m.split(\",\")[0]\n\t\t\t\tself.op.append(op)\n\t\t\t\tself.mod.append('na')\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#If nothing else\n\t\tif len(self.otherMarkers):\n\t\t\tself.op.append(self.otherMarkers[0])\n\t\tself.mod.append('na')\n\n\tdef print(self):\n\t\t\"\"\"\n\t\tPrint kernel information. 
This is used by prof.py.\n\t\t\"\"\"\n\n\t\ta = lambda: None\n\t\ta.kShortName = self.kShortName\n\t\ta.kDuration = self.kDuration\n\t\t#a.layerMarkers = self.layerMarkers\n\t\ta.layer = self.layer\n\t\ta.trace = self.traceMarkers\n\t\ta.reprMarkers = self.reprMarkers\n\t\ta.marker = self.pyprofMarkers\n\t\ta.seqMarker = self.seqMarkers\n\n\t\ta.seqId = self.seqId\n\t\ta.subSeqId = self.subSeqId\n\t\ta.altSeqId = self.altSeqId\n\n\t\ta.dir = self.dir\n\t\ta.mod = self.mod\n\t\ta.op = self.op\n\n\t\ta.tid = self.tid\n\t\ta.device = self.device\n\t\ta.stream = self.stream\n\t\ta.grid = self.grid\n\t\ta.block = self.block\n\t\ta.kLongName = self.kLongName\n\n\t\tprint(a.__dict__)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/nvvp.py",
    "content": "import sys\n\nclass NVVP(object):\n\t\"\"\"\n\tThis class gets kernel information from the SQL (nvvp) database.\n\t\"\"\"\n\n\tdriverT = \"CUPTI_ACTIVITY_KIND_DRIVER\"\n\truntimeT = \"CUPTI_ACTIVITY_KIND_RUNTIME\"\n\tkernelT = \"CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL\"\n\tmarkerT = \"CUPTI_ACTIVITY_KIND_MARKER\"\n\tstringT = \"StringTable\"\n\n\tdef __init__(self, db):\n\t\tself.db = db\n\t\tself.markerId = 0\n\n\tdef getProfileStart(self):\n\t\t\"\"\"\n\t\tGet the profile start time\n\t\t\"\"\"\n\t\tprofStart = sys.maxsize\n\t\tfor table in [self.driverT, self.runtimeT, self.kernelT, self.markerT]:\n\t\t\tcolname = \"timestamp\" if table is self.markerT else \"start\"\n\t\t\tcmd = \"select {} from {} ORDER BY {} ASC LIMIT 1\".format(colname, table, colname)\n\t\t\tresult = self.db.select(cmd)\n\t\t\tassert(len(result) <= 1)\n\t\t\tif (len(result) == 1):\n\t\t\t\tassert(colname in result[0])\n\t\t\t\tt = result[0][colname]\n\t\t\t\tif (t < profStart):\n\t\t\t\t\tprofStart = t\n\t\tassert(profStart < sys.maxsize)\n\t\treturn profStart\n\n\tdef getString(self, id_):\n\t\t\"\"\"\n\t\tGet the string associated with an id.\n\t\t\"\"\"\n\t\tcmd = \"select value from {} where _id_ = {}\".format(self.stringT, id_)\n\t\tresult = self.db.select(cmd)\n\t\tassert (len(result) == 1)\n\t\treturn result[0]['value']\n\n\tdef createMarkerTable(self):\n\t\t\"\"\"\n\t\tCreate a temporary table and index it to speed up repeated SQL quesries.\n\t\tThe table is an INNER JOIN of CUPTI_ACTIVITY_KIND_MARKER with itself.\n\t\t\"\"\"\n\t\tcmd = 'CREATE TEMPORARY TABLE marker AS SELECT \\\n\t\t\t\t\ta._id_ as id, \\\n\t\t\t\t\ta.timestamp AS startTime, \\\n\t\t\t\t\tb.timestamp AS endTime, \\\n\t\t\t\t\tHEX(a.objectId) AS objectId, \\\n\t\t\t\t\ta.name AS name \\\n\t\t\t\t\tFROM {} AS a INNER JOIN {} AS b ON \\\n\t\t\t\t\ta.id = b.id and \\\n\t\t\t\t\ta.flags = 2 and b.flags = 4'.format(self.markerT, self.markerT)\n\t\tself.db.execute(cmd)\n\n\t\tself.db.execute('CREATE INDEX 
start_index ON marker (startTime)')\n\t\tself.db.execute('CREATE INDEX end_index ON marker (endTime)')\n\t\tself.db.execute('CREATE INDEX id_index ON marker (id)')\n\n\tdef getCPUInfo(self, corrId):\n\t\t\"\"\"\n\t\tGiven the correlation id, get CPU start, end, thread id, process id.\n\t\tThe information can be in the runtime table or the driver table.\n\t\t\"\"\"\n\n\t\t#First look in the runtime table\n\t\tcmd = \"select start,end,processId,threadId from {} where correlationId={}\".format(self.runtimeT, corrId);\n\t\tresult = self.db.select(cmd)\n\t\tassert (len(result) <= 1)\n\n\t\tif (len(result) == 0):\n\t\t\t#Look in the driver table\n\t\t\tcmd = \"select start,end,processId,threadId from {} where correlationId={}\".format(self.driverT, corrId);\n\t\t\tresult = self.db.select(cmd)\n\n\t\tassert (len(result) == 1)\n\t\tinfo = result[0]\n\t\tstart = info['start']\n\t\tend = info['end']\n\t\tpid = info['processId']\n\t\ttid = info['threadId']\n\t\ttid = tid & 0xffffffff\t#convert to unsigned\n\t\tassert (end > start)\n\t\treturn [start, end, pid, tid]\n\n\tdef getKernelInfo(self):\n\t\t\"\"\"\n\t\tGet GPU kernel info\n\t\t\"\"\"\n\t\tcmd = \"select name,correlationId,start,end,deviceId,streamId,gridX,gridY,gridZ,blockX,blockY,blockZ from {}\".format(self.kernelT)\n\t\tresult = self.db.select(cmd)\n\t\treturn result\n\n\tdef getMarkerInfo(self, objId, startTime, endTime):\n\t\t\"\"\"\n\t\tThis function first finds all NVTX markers encapsulating\n\t\ta runtime / driver kernel launch.\n\t\tIt then splits the markers into many lists.\n\t\t\tlayerMarkers : User added NVTX markers\n\t\t\ttraceMarkers : Call trace markers (inserted by pyprof)\n\t\t\treprMarkers  : Markers containing the extra_repr() of a module (inserted by pyprof)\n\t\t\tpyprofMarkers: Markers containing args and kwargs (tensor shape, datatype etc.)\n\t\t\tseqMarkers   : Markers containing PyTorch internal sequence markers (inserted by PyTorch)\n\t\t\taltSeqMarkers: Markers inserted by PyTorch between 
two kernel launches. Needs better explanation.\n\t\t\totherMarkers : Markers not in either of the above categories.\n\n\t\tWe extract seqId from the seq and altSeq markers. The seqId is used in bprop.\n\t\tWe also extract information from the layerMarkers.\n\t\t\"\"\"\n\n\t\tlayerMarkers = []\n\t\ttraceMarkers = []\n\t\treprMarkers = []\n\t\tpyprofMarkers = []\n\t\tseqMarkers = []\n\t\totherMarkers = []\n\t\taltSeqMarkers = []\n\t\tbprop = False\n\n\t\t#Helper functions\n\n\t\tdef delete(objId, sTime):\n\t\t\t\"\"\"\n\t\t\tDelete rows from the temporary SQL table which are no longer required.\n\t\t\tThis speeds up future queries.\n\t\t\t\"\"\"\n\t\t\tmargin = 0\n\t\t\tcmd = 'DELETE FROM marker WHERE objectId = \"{}\" AND endTime < {}'.format(objId, sTime - margin)\n\t\t\t#cmd = 'DELETE FROM marker WHERE endTime < {}'.format(sTime - margin)\n\t\t\tself.db.execute(cmd)\n\n\t\tdef getLayerName(mlist):\n\t\t\t\"\"\"\n\t\t\tGet layer names from layer marker list.\n\t\t\t\"\"\"\n\t\t\tlayers = []\n\t\t\tassert(type(mlist) == list)\n\t\t\tfor m in mlist:\n\t\t\t\tassert(\"layer:\" in m)\n\t\t\t\tl = m.split(\":\")[1]\n\t\t\t\tlayers.append(l)\n\t\t\treturn layers\n\n\t\tdef getSeqId(mlist):\n\t\t\t\"\"\"\n\t\t\tGet sequence ids from seq / alt seq marker list.\n\t\t\t\"\"\"\n\t\t\tids = []\n\t\t\tassert(type(mlist) == list)\n\t\t\tfor m in mlist:\n\t\t\t\tassert(\", seq = \" in m)\n\t\t\t\tseq = int(m.split(\"=\")[1])\n\t\t\t\tids.append(seq)\n\n\t\t\t#Remove duplicates\n\t\t\tids = list(set(ids))\n\t\t\tids.sort()\n\t\t\treturn ids\n\n\t\tdef seqcompare(elem):\n\t\t\t\"\"\"\n\t\t\tSorting function for sequence markers\n\t\t\t\"\"\"\n\t\t\tassert (\", seq = \" in elem)\n\t\t\t#sort by sequence id and then the string\n\t\t\tl = elem.split(\" = \")\n\t\t\treturn l[1] + l[0]\n\n\t\tdef prune(mlist):\n\t\t\t\"\"\"\n\t\t\tRemove markers with the same seqId and if the strings are similar.\n\t\t\tThis function works on a sorted sequence.\n\t\t\t\"\"\"\n\t\t\tassert (type(mlist) 
== list)\n\t\t\tassert (len(mlist))\n\t\t\ta = mlist[0:1]\n\t\t\tfor i in range(1,len(mlist)):\n\t\t\t\tm = mlist[i]\n\t\t\t\tpm = mlist[i-1]\n\t\t\t\tname,seq = m.split(\",\")\n\t\t\t\tpname,pseq = pm.split(\",\")\n\t\t\t\tsimilar = (name in pname) or (pname in name)\n\t\t\t\tif (seq == pseq) and similar:\n\t\t\t\t\tcontinue\n\t\t\t\telse:\n\t\t\t\t\ta.append(m)\n\t\t\treturn a\n\n\t\tdef filterTrace(mlist):\n\t\t\t\"\"\"\n\t\t\tFilter trace markers to remove certain file names.\n\t\t\t\"\"\"\n\t\t\tassert (type(mlist) == list)\n\t\t\tif len(mlist) == 0:\n\t\t\t\treturn mlist\n\t\t\tmlist = mlist[-1]\t#The last stack trace will be a super set.\n\t\t\tmlist = eval(mlist)\n\t\t\tmlist = mlist['traceMarker']\n\t\t\tassert (type(mlist) == list)\n\t\t\tmlist = list(filter(lambda x : \"/torch/nn/modules/\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/nn/functional.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/tensor.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/autograd/__init__.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/_jit_internal.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/pyprof/nvtx/nvmarker.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/apex/optimizers/\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/_utils.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/optim/\" not in x, mlist))\n\t\t\treturn mlist\n\n\t\t#Find all encapsulating markers\n\t\tcmd = 'SELECT id,name from marker where \\\n\t\t\t\tobjectId = \"{}\" and \\\n\t\t\t\tstartTime < {} and \\\n\t\t\t\tendTime > {} \\\n\t\t\t\tORDER BY startTime ASC'.format(objId, startTime, endTime)\n\t\tresult = self.db.select(cmd)\n\n\t\t#Bin markers into different lists\n\t\tfor r in result:\n\t\t\tm = self.getString(r['name'])\n\n\t\t\t#Hack: If its a known gradient checkpointing marker, ignore it.\n\t\t\tif m.find(\"CheckpointFunctionBackward\") 
>= 0:\n\t\t\t\tcontinue\n\n\t\t\tif (\"_backward, seq =\" in m) or (\"Backward, seq =\" in m) or (\"Backward0, seq =\" in m):\n\t\t\t\tbprop = True\n\n\t\t\tif (\"mod\" in m) and (\"op\" in m) and (\"args\" in m) and (\"type\" in m):\n\t\t\t\tpyprofMarkers.append(m)\n\t\t\telif (\"layer:\" in m):\n\t\t\t\tlayerMarkers.append(m)\n\t\t\telif (\"traceMarker\" in m):\n\t\t\t\ttraceMarkers.append(m)\n\t\t\telif (\"strRepr\" in m):\n\t\t\t\treprMarkers.append(m)\n\t\t\telif (\", seq = \" in m):\n\t\t\t\tseqMarkers.append(m)\n\t\t\telse:\n\t\t\t\totherMarkers.append(m)\n\n\t\t#Remove duplicates, sort and prune seqMarkers\n\t\tif (len(seqMarkers)):\n\t\t\tseqMarkers = list(set(seqMarkers))\n\t\t\tseqMarkers.sort(key=seqcompare)\n\t\t\tseqMarkers = prune(seqMarkers)\n\n\t\t#Remove duplicates from otherMarkers\n\t\totherMarkers = list(set(otherMarkers))\n\n\t\t#Get markers with seq id (inserted by PyTorch) from the previous kernel to the present kernel\n\t\t#Only for fprop kernels\n\t\tif (len(result) and not bprop):\n\t\t\tloId = self.markerId\n\t\t\thiId = result[-1]['id']\n\t\t\tself.markerId = hiId\n\t\t\t\n\t\t\t#Get markers between loId and hiId\n\t\t\tcmd = 'SELECT id,name from marker where objectId = \"{}\" and id > {} and id < {} ORDER BY startTime ASC'.format(objId, loId, hiId)\n\t\t\tresult1 = self.db.select(cmd)\n\n\t\t\tfor r in result1:\n\t\t\t\tm = self.getString(r['name'])\n\t\t\t\t#Get only markers with seq id (same \", seq = \" format checked by getSeqId and seqcompare)\n\t\t\t\tif (\", seq = \" in m):\n\t\t\t\t\taltSeqMarkers.append(m)\n\n\t\t\t#Remove duplicates, sort and prune altSeqMarkers\n\t\t\tif (len(altSeqMarkers)):\n\t\t\t\taltSeqMarkers = list(set(altSeqMarkers))\n\t\t\t\taltSeqMarkers.sort(key=seqcompare)\n\t\t\t\taltSeqMarkers = prune(altSeqMarkers)\n\n\t\tdelete(objId, startTime)\n\n\t\treturn layerMarkers, filterTrace(traceMarkers), reprMarkers, pyprofMarkers, seqMarkers, otherMarkers, altSeqMarkers, getSeqId(seqMarkers), getSeqId(altSeqMarkers), getLayerName(layerMarkers)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/parse/parse.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nParse the SQL db and print a dictionary for every kernel.\n\"\"\"\n\nimport sys\nimport argparse\nfrom tqdm import tqdm\n\nfrom .db import DB\nfrom .kernel import Kernel\nfrom .nvvp import NVVP\n\ndef parseArgs():\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"Parse SQL (nvvp) db.\")\n\tparser.add_argument(\"file\",\n\t\ttype=str,\n\t\tdefault=None,\n\t\thelp=\"SQL db (nvvp) file.\")\n\n\targs = parser.parse_args()\n\treturn args\n\ndef main():\n\targs = parseArgs()\n\n\tdb = DB(args.file)\n\tnvvp = NVVP(db)\n\n\tkInfo = nvvp.getKernelInfo()\n\tif len(kInfo) == 0:\n\t\tprint(\"Found 0 kernels. Exiting.\", file=sys.stderr)\n\t\tdb.close()\n\t\tsys.exit(0)\n\telse:\n\t\tprint(\"Found {} kernels. Getting info for each kernel.\".format(len(kInfo)), file=sys.stderr)\n\n\tnvvp.createMarkerTable()\n\n\tprevSeqId = -1\n\tprevSubSeqId = -1\n\tprevOp = \"na\"\n\n\tKernel.profStart = nvvp.getProfileStart()\n\n\tfor i in tqdm(range(len(kInfo)), ascii=True):\n\t\tinfo = kInfo[i]\n\t\tk = Kernel()\n\n\t\t#Set kernel info\n\t\tk.setKernelInfo(info)\n\n\t\t#Get, set kernel name\n\t\tname = nvvp.getString(k.kNameId)\n\t\tk.setKernelName(name)\n\n\t\t#Get runtime info\n\t\tinfo = nvvp.getCPUInfo(k.corrId)\n\t\tk.setRunTimeInfo(info)\n\n\t\t#Get and set marker and seqid info\n\t\tinfo = nvvp.getMarkerInfo(k.objId, k.rStartTime, k.rEndTime)\n\t\tk.setMarkerInfo(info)\n\n\t\t#If the seqId contains both 0 and non zero integers, remove 0.\n\t\tif any(seq != 0 for seq in k.seqId) and (0 in k.seqId):\n\t\t\tk.seqId.remove(0)\n\n\t\t#Set direction (it uses seq id)\n\t\tk.setDirection()\n\n\t\t#Set op\n\t\tk.setOp()\n\n\t\t#The following code is based on heuristics.\n\t\t#TODO: Refactor.\n\t\t#Assign subSeqId, adjust seqId and altSeqId\n\t\t#seqId can be 0.\n\t\t#A kernel can have multiple seqIds both in fprop and bprop.\n\t\t#In bprop, seqIds might not decrease monotonically. 
I have observed a few blips.\n\t\tif len(k.seqId):\n\t\t\tassert (k.dir in [\"fprop\", \"bprop\"])\n\t\t\tif (k.dir == \"fprop\"):\n\t\t\t\t#Check if there is a sequence id larger than the previous\n\t\t\t\tinc = (k.seqId[-1] > prevSeqId)\n\t\t\t\tif inc:\n\t\t\t\t\tcurrSeqId = [x for x in k.seqId if x > prevSeqId][0]\n\t\t\t\telse:\n\t\t\t\t\tcurrSeqId = prevSeqId\n\t\t\telse:\n\t\t\t\tcurrSeqId = k.seqId[0]\n\n\t\t\t#if ((currSeqId == prevSeqId) and (k.op == prevOp)):\n\t\t\tif ((currSeqId == prevSeqId) and (k.op == prevOp)) or ((k.op[0] == \"forward\") and (k.op == prevOp) and (k.mod[0] in [\"LSTMCell\", \"GRUCell\", \"RNNCell\"])):\n\t\t\t\t#The second condition is to trap cases when pytorch does not use cudnn for a LSTMCell.\n\t\t\t\tk.subSeqId = prevSubSeqId + 1\n\n\t\t\tprevSeqId = currSeqId\n\t\t\tprevSubSeqId = k.subSeqId\n\t\t\tprevOp = k.op\n\n\t\t\t#Keep currSeqId in k.seqId, move everything else to k.altSeqId\n\t\t\tfor s in k.seqId:\n\t\t\t\tif s != currSeqId:\n\t\t\t\t\tk.seqId.remove(s)\n\t\t\t\t\tk.altSeqId.append(s)\n\n\t\t\tfor s in k.altSeqId:\n\t\t\t\tif s == currSeqId:\n\t\t\t\t\tk.altSeqId.remove(s)\n\n\t\t\tk.altSeqId = list(set(k.altSeqId))\n\t\t\tif (len(k.altSeqId)):\n\t\t\t\t(k.altSeqId).sort()\n\n\t\tk.print()\n\n\tdb.close()\n\nif __name__ == '__main__':\n\tmain()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/__init__.py",
    "content": "from . import data, prof\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/__main__.py",
    "content": "import warnings\n\ntry:\n    from .prof import main\nexcept ImportError as e:\n    warnings.warn(\"Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?\")\n    raise e\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/activation.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Activation(OperatorLayerBase):\n\t\"\"\"\n\tThis class handles the various activation functions.\n\t\"\"\"\n\n\tops = [\"celu\", \"elu\", \"elu_\", \"hardshrink\", \"hardtanh\", \"hardtanh_\", \"leaky_relu\", \"leaky_relu_\", \"logsigmoid\", \"prelu\", \"relu\", \"relu_\", \"relu6\", \"rrelu\", \"rrelu_\", \"selu\", \"sigmoid\", \"softplus\", \"softshrink\", \"softsign\", \"tanh\", \"tanhshrink\", \"threshold\", \"threshold_\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch.nn.functional\", \"torch\", \"Tensor\"])\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) >= 1)\n\t\targ = args[0]\n\t\tassert (arg['type'] == \"tensor\")\n\n\t\tself.i = arg\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.i['shape']),('type', self.i['dtype'])])\n\t\treturn p\n\n\tdef flops(self):\n\t\tdirection = self.dir\n\t\ttensor = self.i['shape']\n\t\tt = self.i['dtype']\n\n\t\t# TODO: revise\n\t\telems = Utility.numElems(tensor)\n\t\treturn elems\n\n\tdef bytes(self):\n\t\tdirection = self.dir\n\t\ttensor = self.i['shape']\n\t\tt = self.i['dtype']\n\n\t\telems = Utility.numElems(tensor)\n\t\telems = elems * (2 if direction == \"fprop\" else 3)\n\n\t\treturn elems * Utility.typeToBytes(t)\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/base.py",
    "content": "from abc import ABC, abstractmethod\n\nclass OperatorLayerBase(ABC):\n\t\"\"\"\n\tBase class for all layers and operators.\n\tEvery derived class should have the following functions.\n\t\"\"\"\n\n\t@abstractmethod\n\tdef tc(self):\n\t\t\"\"\"\n\t\tTensor core usage by the kernel.\n\t\tReturn \"1\" (yes), \"0\" (no, but possible), \"-\" (not applicable)\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef params(self):\n\t\t\"\"\"\n\t\tKernel parameters to be printed.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef flops(self):\n\t\t\"\"\"\n\t\tNote that 1 FMA = 2 flops.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef bytes(self):\n\t\tpass\n\n\t@abstractmethod\n\tdef mod(self):\n\t\t\"\"\"\n\t\tName of the module/class e.g. torch.nn.functional.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef op(self):\n\t\t\"\"\"\n\t\tName of the operator e.g. sigmoid.\n\t\t\"\"\"\n\t\tpass\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/blas.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\nimport numpy as np\n\nTC_GEMMS = [\"884gemm\", \"1688gemm\"]\n\nclass Addmm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\",])\n\t\tassert (op in [\"addmm\", \"addmm_\",])\n\n\t\t#Get alpha and beta\n\t\talpha = 1\n\t\tbeta = 1\n\t\tif any(x['name'] == 'alpha' for x in args):\n\t\t\talpha = list(filter(lambda x : x['name'] == \"alpha\", args))[0]\n\t\t\talpha = alpha['value']\n\n\t\tif any(x['name'] == 'beta' for x in args):\n\t\t\tbeta = list(filter(lambda x : x['name'] == \"beta\", args))[0]\n\t\t\tbeta = beta['value']\n\n\t\tself.alpha = alpha\n\t\tself.beta = beta\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) == 3)\n\t\tC,A,B = args\n\t\tm,k1 = A['shape']\n\t\tk2,n = B['shape']\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tt3 = C['dtype']\n\t\tassert(t1 == t2 == t3)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.C = C\n\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\t\treturn\n\n\tdef tc(self):\n            for s in TC_GEMMS:\n                if s in self.name:\n                    return 1\n            return 0\n\n\tdef bytes(self):\n\t\tm, n, k = self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\treturn self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef params(self):\n\t\tp = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\nclass Bmm(OperatorLayerBase):\n\n\tdef __init__(self, 
d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\") and (op == \"bmm\")\n\n\t\t#Filter out named params (kwargs)\n\t\targs = list(filter(lambda x : x['name'] == \"\", args))\n\n\t\tassert (len(args) == 2)\n\t\tA,B = args\n\t\tb1,m,k1 = A['shape']\n\t\tb2,k2,n = B['shape']\n\t\tassert (b1 == b2)\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.b = b1\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\tdef tc(self):\n            for s in TC_GEMMS:\n                if s in self.name:\n                    return 1\n            return 0\n\n\tdef params(self):\n\t\t#p = OrderedDict([('A', A['shape']), ('B', B['shape']), ('type', t1)])\n\t\tp = OrderedDict([('B',self.b), ('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn self.b * self.m * self.n * self.k * 2\n\n\tdef bytes(self):\n\t\tb, m, n, k = self.b, self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * b * (m*n + m*k + n*k)\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Matmul(OperatorLayerBase):\n\n\tNON_GEMM = [\"kernelPointwiseApply2\", \"reduce_1Block_kernel\", \"elementwise_kernel\"]\n\tNON_TC = NON_GEMM + [\"dot_kernel\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.name = d.name\n\t\tself.sub = d.sub\n\n\t\tassert ((mod == \"torch\") and (op == \"matmul\")) or ((mod == \"Tensor\") and (op == \"__matmul__\"))\n\t\tassert (len(args) == 2)\n\n\t\tassert any([x in d.name for x in Matmul.NON_TC + 
[\"gemm\", \"gemv\"]])\n\n\t\tA,B = args\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tA = A['shape']\n\t\tB = B['shape']\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.type = t1\n\n\t\t# batch, MNK\n\t\tif (len(A) == 1) and (len(B) == 1):\n\t\t\t#dot product\n\t\t\tassert (A[0] == B[0])\n\t\t\tself.b = (1,)\n\t\t\tself.m = 1\n\t\t\tself.n = 1\n\t\t\tself.k = A[0]\n\n\t\telif (len(A) == 2) and (len(B) == 2):\n\t\t\t#gemm\n\t\t\tm,k1 = A\n\t\t\tk2,n = B\n\t\t\tassert(k1 == k2)\n\t\t\tself.b = (1,)\n\t\t\tself.m = m\n\t\t\tself.n = n\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 1) and (len(B) == 2):\n\t\t\t#vector matrix\n\t\t\tk1 = A[0]\n\t\t\tk2,n = B\n\t\t\tassert(k1 == k2)\n\n\t\t\tself.b = (1,)\n\t\t\tself.m = 1\n\t\t\tself.n = n\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 2) and (len(B) == 1):\n\t\t\t#gemv\n\t\t\tm,k1 = A\n\t\t\tk2 = B[0]\n\t\t\tassert (k1 == k2)\n\n\t\t\tself.b = (1,)\n\t\t\tself.m = m\n\t\t\tself.n = 1\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 1) and (len(B) > 2):\n\t\t\tassert (A[0] == B[-2])\n\n\t\t\tself.b = B[0:-2]\n\t\t\tself.m = 1\n\t\t\tself.n = B[-1]\n\t\t\tself.k = B[-2]\n\n\t\telif (len(B) == 1) and (len(A) > 2):\n\t\t\tassert (B[0] == A[-1])\n\n\t\t\tself.b = A[0:-2]\n\t\t\tself.m = A[-2]\n\t\t\tself.n = 1\n\t\t\tself.k = A[-1]\n\n\t\telse:\n\t\t\tassert (len(A) >= 2)\n\t\t\tassert (len(B) >= 2)\n\t\t\tassert (A[-1] == B[-2])\n\t\t\tself.m = A[-2]\n\t\t\tself.n = B[-1]\n\t\t\tself.k = A[-1]\n\n\t\t\taa = np.empty(A[0:-2])\n\t\t\tbb = np.empty(B[0:-2])\n\t\t\tself.b = np.broadcast(aa, bb).shape\n\n\tdef params(self):\n\t\treturn OrderedDict([('A', self.A), ('B', self.B), ('type', self.type)])\n\n\tdef tc(self):\n\t\tif self.name in Matmul.NON_TC:\n\t\t\treturn \"-\"\n\t\telse:\n                    for s in TC_GEMMS:\n                        if s in self.name:\n                            return 1\n                    return 0\n\n\tdef bytes(self):\n\t\t# TODO: check bytes for non-GEMM cases\n\t\tif self.name in 
Matmul.NON_GEMM:\n\t\t\treturn 2 * Utility.typeToBytes(self.type) * Utility.numElems(self.A) #could be B as well\n\t\telse:\n\t\t\tm, n, k = self.m, self.n, self.k\n\t\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\t# TODO: calculate actual FLOPs. At least we're not saying it's GEMM FLOPs for now.\n\t\tif self.name in Matmul.NON_GEMM:\n\t\t\treturn 0\n\t\telse:\n\t\t\treturn Utility.numElems(self.b) * self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Mm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\") and (op == \"mm\")\n\t\tassert (len(args) == 2)\n\n\t\tA,B = args\n\t\tm,k1 = A['shape']\n\t\tk2,n = B['shape']\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\tfor s in TC_GEMMS:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tm, n, k = self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\treturn self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/conv.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Conv(OperatorLayerBase):\n\n\t\"\"\"\n\t# N = batch size\n\t# C,H,W = input channels, height, width\n\t# K,P,Q = output channels, height, width\n\t# R,S = filter height, width\n\t# g = groups\n\t\"\"\"\n\n\t#todo: refine winograd and FFT\n\tconvAuxList = [\"nchwToNhwc\", \"nhwcToNchw\", \"OffsetsKernel\",]\n\twinoAuxList = [\"generateWinogradTilesKernel\", \"winogradWgradData\", \"winogradWgradOutput\", \"winogradWgradDelta\"]\n\tfftAuxList = [\"compute_gemm_pointers\", \"flip_filter\", \"fft2d_r2c_\", \"fft2d_c2r_\", \"fft1d_r2c\", \"fft1d_c2r\"]\n\tmiscAuxList = [\"scaleTensor_kernel\",]\n\n\tconvList = [\"_s884cudnn_\", \"_s1688cudnn_\", \"_scudnn_\", \"2d_grouped_direct_kernel\", \"cudnn::detail::implicit_convolve_sgemm\", \"cudnn::detail::dgrad2d_alg1_1\", \"cudnn::detail::wgrad_alg0_engine\", \"cudnn::detail::dgrad_engine\", \"dgrad_1x1_stride_2x2\", \"spatialDepthwiseConvolutionUpdateOutput\"]\n\twinoList = [\"winograd3x3Kernel\", \"_sgemm_\"]\n\tfftList = [\"fermiPlusCgemmLDS128_batched\", \"_gcgemm_\",]\n\tmiscList = []\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.dir = d.dir\n\t\tself.name = d.name\n\t\tself.sub = d.sub\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op in [\"conv1d\", \"conv2d\"])\n\t\tlength = len(args)\n\t\tassert (length >= 2) and (length <= 7)\n\t\ti,w = args[0], args[1]\n\t\tassert (i['type'] == \"tensor\")\n\t\tassert (w['type'] == \"tensor\")\n\n\t\t#ignore bias\n\n\t\tif (length >= 4) and (args[3]['name'] == \"\"):\n\t\t\ts = args[3]\n\t\telif any(x['name'] == 'stride' for x in args):\n\t\t\ts = list(filter(lambda x : x['name'] == 'stride', args))[0]\n\t\telse:\n\t\t\ts = {'name': 'stride', 'type': 
'int', 'value': 1}\n\n\t\tif (length >= 5) and (args[4]['name'] == \"\"):\n\t\t\tp = args[4]\n\t\telif any(x['name'] == 'padding' for x in args):\n\t\t\tp = list(filter(lambda x : x['name'] == 'padding', args))[0]\n\t\telse:\n\t\t\tp = {'name': 'padding', 'type': 'int', 'value': 0}\n\n\t\tif (length >= 6) and (args[5]['name'] == \"\"):\n\t\t\td = args[5]\n\t\telif any(x['name'] == 'dilation' for x in args):\n\t\t\td = list(filter(lambda x : x['name'] == 'dilation', args))[0]\n\t\telse:\n\t\t\td = {'name': 'dilation', 'type': 'int', 'value': 1}\n\n\t\tif (length == 7) and (args[6]['name'] == \"\"):\n\t\t\tg = args[6]\n\t\telif any(x['name'] == 'groups' for x in args):\n\t\t\tg = list(filter(lambda x : x['name'] == 'groups', args))[0]\n\t\telse:\n\t\t\tg = {'name': 'groups', 'type': 'int', 'value': 1}\n\n\t\tif op == \"conv1d\":\n\t\t\tassert (len(i['shape']) == 3)\n\t\t\tassert (len(w['shape']) == 3)\n\t\t\tassert (i['dtype'] == w['dtype'])\n\t\t\tN, C1, W = i['shape']\n\t\t\tK, C2, S = w['shape']\n\t\t\tassert (C1 == C2)\n\t\t\tp = p['value'] if Utility.isscalar(p['type']) else p['value'][0]\n\t\t\ts = s['value'] if Utility.isscalar(s['type']) else s['value'][0]\n\t\t\td = d['value'] if Utility.isscalar(d['type']) else d['value'][0]\n\t\t\tg = g['value']\n\t\t\tassert (g == 1)\n\t\t\tH = 1\n\t\t\tR = 1\n\n\t\t\tP = 1 + (H - (((R-1))+1))\n\t\t\tQ = 1 + (W + 2*p - (((S-1)*d)+1))/s\n\t\t\tP = int(P)\n\t\t\tQ = int(Q)\n\t\t\tif (H == 1):\n\t\t\t\tassert (P == 1)\n\t\t\tif (W == 1):\n\t\t\t\tassert (Q == 1)\n\n\t\t\tself.N = N\n\t\t\tself.C = C1\n\t\t\tself.H = H\n\t\t\tself.W = W\n\t\t\tself.K = K\n\t\t\tself.P = P\n\t\t\tself.Q = Q\n\t\t\tself.R = R\n\t\t\tself.S = S\n\t\t\tself.ph = 0\n\t\t\tself.pw = p\n\t\t\tself.U = 1\n\t\t\tself.V = s\n\t\t\tself.dh = 1\n\t\t\tself.dw = d\n\t\t\tself.g = g\n\t\t\tself.type = i['dtype']\n\n\t\telif op == \"conv2d\":\n\t\t\tassert (len(i['shape']) == 4)\n\t\t\tassert (len(w['shape']) == 4)\n\t\t\tassert (i['dtype'] == 
w['dtype'])\n\t\t\tN, C1, H, W = i['shape']\n\t\t\tK, C2, R, S = w['shape']\n\n\t\t\tif Utility.isscalar(p['type']):\n\t\t\t\tph = pw = p['value']\n\t\t\telse:\n\t\t\t\tassert (p['type'] == \"tuple\")\n\t\t\t\tph, pw = p['value']\n\n\t\t\tif Utility.isscalar(s['type']):\n\t\t\t\tsh = sw = s['value']\n\t\t\telse:\n\t\t\t\tassert (s['type'] == \"tuple\")\n\t\t\t\tsh, sw = s['value']\n\n\t\t\tif Utility.isscalar(d['type']):\n\t\t\t\tdh = dw = d['value']\n\t\t\telse:\n\t\t\t\tassert (d['type'] == \"tuple\")\n\t\t\t\tdh, dw = d['value']\n\n\t\t\tg = g['value']\n\t\t\tassert (g >= 1)\n\t\t\tassert (C1 == C2*g)\n\n\t\t\tP = 1 + (H + 2*ph - (((R-1)*dh)+1))/sh\n\t\t\tQ = 1 + (W + 2*pw - (((S-1)*dw)+1))/sw\n\t\t\tP = int(P)\n\t\t\tQ = int(Q)\n\t\t\tif (H == 1):\n\t\t\t\tassert (P == 1)\n\t\t\tif (W == 1):\n\t\t\t\tassert (Q == 1)\n\n\t\t\tself.N = N\n\t\t\tself.C = C1\n\t\t\tself.H = H\n\t\t\tself.W = W\n\t\t\tself.K = K\n\t\t\tself.P = P\n\t\t\tself.Q = Q\n\t\t\tself.R = R\n\t\t\tself.S = S\n\t\t\tself.ph = ph\n\t\t\tself.pw = pw\n\t\t\tself.U = sh\n\t\t\tself.V = sw\n\t\t\tself.dh = dh\n\t\t\tself.dw = dw\n\t\t\tself.g = g\n\t\t\tself.type = i['dtype']\n\n\t\telse:\n\t\t\tassert False\n\n\tdef params(self):\n\t\tp = OrderedDict([('N',self.N), ('C',self.C), ('H',self.H), ('W',self.W), ('K',self.K), ('P',self.P), ('Q',self.Q), ('R',self.R), ('S',self.S), ('ph',self.ph), ('pw',self.pw), ('U',self.U), ('V',self.V), ('dh',self.dh), ('dw',self.dw), ('g',self.g), ('type',self.type)])\n\t\treturn p\n\n\tdef conv_bytes_flops(self, N, C, H, W, K, P, Q, R, S, g, t):\n\t\tf = 2*N*K*P*Q*C*R*S/g #for fprop\n\t\telems = N*C*H*W + K*C*R*S/g + N*K*P*Q\n\t\tb = elems * Utility.typeToBytes(t)\n\t\treturn b,f\n\n\tdef bytes_flops(self):\n\t\tN,C,H,W,K,P,Q,R,S,ph,pw,U,V,dh,dw,g,t = self.params().values()\n\n\t\tif any(x in self.name for x in Conv.convAuxList+Conv.winoAuxList+Conv.fftAuxList+Conv.miscAuxList):\n\t\t\tbytes, flops = [0, 0]\n\n\t\telif any(x in self.name for x in 
Conv.convList+Conv.winoList+Conv.fftList+Conv.miscList):\n\t\t\tif g == 1:\n\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\telse:\n\t\t\t\tif \"2d_grouped_direct_kernel\" in self.name:\t#only 1 kernel is called\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\t\telif \"spatialDepthwiseConvolutionUpdateOutput\" in self.name: #one kernel for separable conv\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\t\telse:\t#a kernel per group is called\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C/g,H,W,K/g,P,Q,R,S,1,t)\n\n\t\telif (\"calc_bias_diff\" in self.name):\t#bias gradient\n\t\t\telems = N*K*P*Q\n\t\t\tflops = elems\n\t\t\tbytes = 2 * elems * Utility.typeToBytes(t)\n\t\t\t#params = OrderedDict([('N',N), ('K',K), ('P',P), ('Q',Q), ('type', t)])\n\n\t\telse:\n\t\t\tbytes, flops = [0, 0]\n\n\t\treturn bytes, flops\n\n\tdef bytes(self):\n\t\tb,_ = self.bytes_flops()\n\t\treturn b\n\n\tdef flops(self):\n\t\t_,f = self.bytes_flops()\n\t\treturn f\n\n\tdef tc(self):\n\t\tfor s in [\"884cudnn\", \"1688cudnn\"]:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/convert.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Convert(OperatorLayerBase):\n\t\"\"\"\n\tClass to handle convert operations.\n\t\"\"\"\n\tops = [\"byte\", \"char\", \"double\", \"float\", \"half\", \"int\", \"long\", \"short\", \"to\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op in Convert.ops)\n\t\tassert (len(args) == 1)\n\n\t\t#The argument could be a tensor or scalar\n\t\tt = args[0]\n\t\tif t['type'] == \"tensor\":\n\t\t\tshape = t['shape']\n\t\t\tstype = t['dtype']\n\t\telse:\n\t\t\tshape = (1,)\n\t\t\tstype = t['type']\n\t\tif self.op_ == \"to\":\n\t\t\top = stype\n\n\t\tself.shape = shape\n\t\tself.stype = stype\n\t\tself.dtype = op\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * (Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype))\n\t\treturn b\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/data.py",
    "content": "from .utility import Utility\n\nclass Data(object):\n\t\"\"\"\n\tClass to store all the data for every kernel e.g. name, bytes, flops, device, stream etc.\n\t\"\"\"\n\tdef __init__(self, kernel):\n\t\t#Available from NVprof\n\t\tself.tid = kernel['tid']\n\t\tself.device = kernel['device']\n\t\tself.stream = kernel['stream']\n\t\tself.grid = str(kernel['grid']).replace(\" \",\"\").replace(\"(\",\"\").replace(\")\",\"\")\n\t\tself.block = str(kernel['block']).replace(\" \",\"\").replace(\"(\",\"\").replace(\")\",\"\")\n\t\tself.name = kernel['kShortName'].replace(\" \",\"_\")\n\t\tself.lName = kernel['kLongName']\n\t\tself.sil = kernel['kDuration']\t#units ns\n\n\t\tself.index = None\n\n\t\t#Markers\n\t\tself.argMarker = kernel['marker']\n\t\tself.modMarker = kernel['reprMarkers']\n\t\tself.seqMarker = kernel['seqMarker']\n\n\t\tself.layer = kernel['layer']\n\t\tself.trace = kernel['trace']\n\n\t\tself.seqId = kernel['seqId']\n\t\tself.altSeqId = kernel['altSeqId']\n\n\t\tself.dir = kernel['dir']\n\t\tself.sub = kernel['subSeqId']\n\n\t\tself.mod = \"na\"\n\t\tself.op = \"na\"\n\t\tself.params = {\"na\":\"na\"}\n\t\tself.tc = \"na\"\n\t\tself.flops = 0\n\t\tself.bytes = 0\n\n\tdef setParams(self, params):\n\t\t#Remove space from params\n\t\tqaz = \"\"\n\t\tfor key,value in params.items():\n\t\t\tif \"type\" not in key:\n\t\t\t\tqaz += \"{}={},\".format(key,value)\n\t\t\telse:\n\t\t\t\tif type(value) is str:\n\t\t\t\t\tqaz += \"{},\".format(Utility.typeToString(value))\n\t\t\t\telse:\n\t\t\t\t\tqaz += \"{}\".format(value)\n\n\t\tself.params = qaz.replace(\" \", \"\")\n\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/dropout.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Dropout(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"dropout\")\n\t\t#assert (len(args) == 1)\n\n\t\tself.shape = args[0]['shape']\n\t\tself.type  = args[0]['dtype']\n\t\tself.dir = d.dir\n\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\t#Ignoring the cost of writing and reading the mask\n\t\treturn Utility.typeToBytes(self.type) * self.elems() * 2\n\n\tdef flops(self):\n\t\t# Note: This is approximate and depends on the RNG\n\t\treturn 5*self.elems()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/embedding.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Embedding(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"embedding\")\n\n\t\tself.ishape = args[0]['shape']\n\t\tself.itype = args[0]['dtype']\n\n\t\tself.eshape = args[1]['shape']\n\t\tself.etype = args[1]['dtype']\n\n\t\tassert (len(self.eshape) == 2)\n\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('I', self.ishape), ('itype', self.itype), ('E', self.eshape), ('etype', self.etype)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef bytes(self):\n\t\tishape = self.ishape\n\t\titype = self.itype\n\t\teshape = self.eshape\n\t\tetype = self.etype\n\n\t\tielems = Utility.numElems(ishape)\n\n\t\tb = 0\n\t\tif self.dir == \"fprop\":\n\t\t\t#indices\n\t\t\tb += ielems * Utility.typeToBytes(itype)\n\t\t\t#read and write the embedding matrix\n\t\t\tb += ielems * eshape[1] * 2 * Utility.typeToBytes(etype)\n\t\telse:\n\t\t\t#3 times the size of the incoming gradient\n\t\t\tb = ielems * eshape[1] * 3 * Utility.typeToBytes(etype)\n\n\t\t\tif self.sub > 0:\n\t\t\t\tb = 0\n\n\t\treturn b\n\n\tdef flops(self):\n\t\t# Note: not implemented yet\n\t\treturn 0\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/index_slice_join_mutate.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nimport numpy as np\nfrom .base import OperatorLayerBase\n\nclass Cat(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\")\n\t\tassert (op == \"cat\")\n\t\tassert (len(args) >= 2)\n\n\t\tt = args[0]['dtype']\n\t\tshapes = []\n\n\t\tfor arg in args:\n\t\t\tif arg['type'] == \"tensor\":\n\t\t\t\tassert (arg['dtype'] == t)\n\t\t\t\tshapes.append(arg['shape'])\n\n\t\tself.type = t\n\t\tself.shapes = shapes\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shapes), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\tb = 0\n\t\tfor s in self.shapes:\n\t\t\tb += Utility.numElems(s)\n\t\treturn 2 * b * Utility.typeToBytes(self.type)\n\nclass Reshape(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"reshape\")\n\n\t\t#Temporarily commenting three lines\n\t\t#assert (len(args) == 2)\n\t\t#t,s = args\n\t\t#assert s['type'] == \"tuple\"\n\n\t\tt = args[0]\n\t\tassert t['type'] == \"tensor\"\n\t\tself.type = t['dtype']\n\t\tself.shape = t['shape']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn 0\n\nclass 
Gather(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"gather\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 3)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\targ = args[0]\n\t\telse:\n\t\t\targ = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tassert (arg['type'] == \"tensor\")\n\n\t\tself.shape = arg['shape']\n\t\tself.type = arg['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n\nclass MaskedScatter(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"masked_scatter_\")\n\t\tassert (len(args) == 3)\n\n\t\tdst, mask, src = args\n\t\tassert (dst['type'] == mask['type'] == src['type'] == \"tensor\")\n\t\tassert (mask['dtype'] == \"uint8\")\n\t\tassert (dst['dtype'] == src['dtype'])\n\t\tassert (dst['shape'] == mask['shape'])\n\n\t\tself.shape = dst['shape']\n\t\tself.type = dst['dtype']\n\t\tself.seqId = d.seqId\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef 
mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\telems = Utility.numElems(self.shape)\n\n\t\t#src and dst\n\t\tb = 2 * elems * Utility.typeToBytes(self.type)\n\n\t\t#mask (uint8)\n\t\tb += elems\n\n\t\tif (self.seqId > 0):\n\t\t\tb = 0\n\t\treturn b\n\nclass Nonzero(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"nonzero\")\n\t\tassert (len(args) == 1)\n\n\t\targ = args[0]\n\t\tself.shape = arg['shape']\n\t\tself.type = arg['dtype']\n\t\tself.seqId = d.seqId\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\telems = Utility.numElems(self.shape)\n\t\tdim = len(self.shape)\n\n\t\t#input tensor\n\t\tb = elems * Utility.typeToBytes(self.type)\n\n\t\t#in the worst case, the output is a (elems x dim) tensor of type \"long\"\n\t\tb += elems * dim * Utility.typeToBytes(\"int64\")\n\n\t\tif self.seqId > 0:\n\t\t\treturn 0\n\t\telse:\n\t\t\treturn b\n\nclass IndexSelect(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"index_select\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 3)\n\n\t\t#Get input, dim and index\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tt = args[0]\n\t\telse:\n\t\t\tt = list(filter(lambda x : x['name'] == 
\"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\td = args[1]\n\t\telse:\n\t\t\td = list(filter(lambda x : x['name'] == \"dim\", args))[0]\n\n\t\tif (args[2]['name'] == \"\"):\n\t\t\ti = args[2]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"index\", args))[0]\n\n\t\tassert (t['type'] == i['type'] == \"tensor\")\n\t\tassert (d['type'] == \"int\")\n\t\tassert (i['dtype'] == \"int64\")\n\t\tassert (len(i['shape']) == 1)\n\n\t\tshape = t['shape']\n\t\tdim = d['value']\n\t\tindices = i['shape'][0]\n\t\tassert (dim < len(shape))\n\n\t\tself.shape = shape\n\t\tself.dim = dim\n\t\tself.indices = indices\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('D', self.dim),('I', self.indices),('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\t#determine the shape of the output tensor\n\t\tshape = list(self.shape)\n\t\tshape[self.dim] = self.indices\n\n\t\tb = 0\n\n\t\t#time to read the input and write the output\n\t\telems = Utility.numElems(shape)\n\t\tb += 2 * elems * Utility.typeToBytes(self.type)\n\n\t\t#time to read the indices\n\t\tb += self.indices * Utility.typeToBytes(\"int64\")\n\n\t\treturn b\n\nclass MaskedSelect(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\t\tself.sub = d.sub\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"masked_select\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 2)\n\n\t\t#Get input and mask\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tt = args[0]\n\t\telse:\n\t\t\tt = list(filter(lambda x : x['name'] 
== \"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\tm = args[1]\n\t\telse:\n\t\t\tm = list(filter(lambda x : x['name'] == \"mask\", args))[0]\n\n\t\tassert (m['dtype'] == \"uint8\")\n\n\t\ttensor = t['shape']\n\t\tmask = m['shape']\n\n\t\t#check for broadcast condition\n\t\tif (tensor != mask):\n\t\t\tarray1 = np.empty(list(tensor))\n\t\t\tarray2 = np.empty(list(mask))\n\t\t\ttry:\n\t\t\t\tout = np.broadcast(array1, array2).shape\n\t\t\texcept:\n\t\t\t\tassert False\n\n\t\tself.tshape = tensor\n\t\tself.mshape = mask\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.tshape),('M', self.mshape),('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\ttensor = self.tshape\n\t\tmask = self.mshape\n\t\tt = self.type\n\n\t\t#in the worst case, #output elements = #input elements\n\t\tb = 2 * Utility.numElems(tensor) * Utility.typeToBytes(t)\n\n\t\t#mask tensor (assuming uint8)\n\t\tb += Utility.numElems(mask)\n\t\treturn b\n\n\tdef flops(self):\n\t\treturn 0\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/linear.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Linear(OperatorLayerBase):\n\n\t'''\n\tNotes:\n\tIf the bias occurs before the GEMM, then its 1 write (bias expansion).\n\tIf the bias occurs after, then its 1 read and 1 write.\n\tbias in bprop is a reduction and hence is 1 read.\n\t'''\n\n\tgemmKernels = [\"gemm\", \"gemv\", \"dot_kernel\", \"splitKreduce_kernel\", \"reduce_1Block_kernel\"]\n\tbiasKernels = [\"kernelReduceContigDim\", \"kernelReduceNoncontigDim_shared\", \"elementwise_kernel\", \"reduce_kernel\"]\n\n\tdef setXWBMNK(self, args):\n\t\tx = None\n\t\tw = None\n\t\tb = None\n\t\tif (len(args) == 2):\n\t\t\tx,w = args\n\t\telif (len(args) == 3):\n\t\t\tx,w,b = args\n\t\t\tassert (x['type'] == w['type'] == \"tensor\")\n\t\t\tif (b['type'] == \"tensor\"):\n\t\t\t\tassert(len(b['shape']) == 1)\n\t\t\telif (b['type'] == \"NoneType\"):\n\t\t\t\tassert b['value'] is None\n\t\t\t\tb = None\n\t\t\telse:\n\t\t\t\tassert False\n\t\telse:\n\t\t\tassert False\n\n\t\tassert(len(w['shape']) == 2)\n\t\tk1 = x['shape'][-1]\n\t\tn,k2 = w['shape']\n\t\tassert(k1 == k2)\n\t\tif b is not None:\n\t\t\tassert(b['shape'][0] == n)\n\t\tt1 = x['dtype']\n\t\tt2 = w['dtype']\n\t\tassert(t1 == t2)\n\n\t\t# X, W, B\n\t\tself.x = x['shape']\n\t\tself.w = w['shape']\n\t\tself.b = b['shape'] if b is not None else None\n\t\tself.type = t1\n\n\t\t# M, N, K\n\t\t#n = Utility.numElems(x[0:-1])\n\t\tn = self.x[0:-1]\n\t\tk = self.x[-1]\n\t\tm,k1 = self.w\n\t\tassert (k == k1)\n\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k\n\n\tdef tc(self):\n\t\tif self.op() == \"linear\":\n\t\t\treturn 1 if \"884gemm\" in self.name else 0\n\t\telse:\n\t\t\treturn \"-\"\n\n\tdef __init__(self, d):\n\t\tself.name = d.name\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tassert (mod == 
\"torch.nn.functional\")\n\t\tassert (op == \"linear\")\n\n\t\t#store the module so mod() does not raise AttributeError\n\t\tself.mod_ = mod\n\n\t\tself.setXWBMNK(args)\n\n\t\tif any(x in d.name for x in Linear.gemmKernels):\n\t\t\tself.op_ = \"linear\"\n\t\telse:\n\t\t\tassert (d.name in Linear.biasKernels)\n\t\t\tself.op_ = \"bias\"\n\n\t\t'''\n\t\telif ((\"kernelPointwiseApply2\" in d.name) or (\"kernelReduceContigDim\" in d.name) or (\"kernelReduceNoncontigDim_shared\" in d.name)):\n\t\t\t#bias expansion was before the gemm\n\t\t\tself.op_ = \"bias\"\n\n\t\telif (\"elementwise_kernel\" in d.name):\n\t\t\t#Bias addition happens later with a broadcast tensor\n\t\t\tself.op_ = \"bias\"\n\t\t\tassert (len(d.argMarker) == 2)\n\t\t\tmarker = eval(d.argMarker[1])\n\t\t\tmod = marker['mod']\n\t\t\top = marker['op']\n\t\t\targs = marker['args']\n\n\t\t\tassert (mod == \"Tensor\")\n\t\t\tassert (op == \"__iadd__\")\n\t\t\tassert (len(args) == 2)\n\t\t\tmn = args[0]['shape']\n\t\t\tb = args[1]['shape']\n\t\t\tassert (len(b) == 1)\n\n\t\t\tassert (mn == (self.n + (self.m,)))\n\t\t\tassert (b == self.b)\n\n\t\telse:\n\t\t\tassert False\n\t\t'''\n\n\tdef params(self):\n\t\t#p = OrderedDict([('X', self.x), ('W', self.w), ('B', self.b), ('type', self.type)])\n\n\t\tm, n, k, x, w, t = self.m, self.n, self.k, self.x, self.w, self.type\n\t\tif len(n) == 1:\n\t\t\tn = n[0]\n\n\t\tif self.op_ == \"linear\":\n\t\t\tif self.dir == \"fprop\":\n\t\t\t\tp = OrderedDict([('M', m), ('N', n), ('K', k), ('type', t)])\n\t\t\telif self.dir == \"bprop\":\n\t\t\t\tif self.sub == 0:\t\t#dgrad (most likely)\n\t\t\t\t\tp = OrderedDict([('M', k), ('N', n), ('K', m), ('type', t)])\n\t\t\t\telif self.sub == 1:\t#wgrad (most likely)\n\t\t\t\t\tp = OrderedDict([('M', k), ('N', m), ('K', n), ('type', t)])\n\t\t\t\telse:\n\t\t\t\t\t#This happens when there are additional kernels for reduction\n\t\t\t\t\tp = OrderedDict([('X', x), ('W', w), ('type', t)])\n\t\t\telse:\n\t\t\t\tassert False\n\n\t\telif self.op_ == \"bias\":\n\t\t\tp = OrderedDict([('M', m), ('N', n), ('type', 
t)])\n\t\telse:\n\t\t\tassert False\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef bytesFlops(self):\n\n\t\tm = self.m\n\t\tn = Utility.numElems(self.n)\n\t\tk = self.k\n\n\t\tif self.op_ == \"linear\":\n\t\t\tif self.dir == \"fprop\":\n\t\t\t\tf = m * n * k * 2\n\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\telif self.dir == \"bprop\":\n\t\t\t\tif self.sub == 0:\t\t#dgrad (most likely)\n\t\t\t\t\tf = m * n * k * 2\n\t\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\t\telif self.sub == 1:\t#wgrad (most likely)\n\t\t\t\t\tf = m * n * k * 2\n\t\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\t\telse:\n\t\t\t\t\t#This happens when there are additional kernels for reduction\n\t\t\t\t\tf = 0\n\t\t\t\t\tb = 0\n\t\t\telse:\n\t\t\t\tassert False\n\n\t\telif self.op_ == \"bias\":\n\t\t\tf = m * n\n\t\t\tb = 2 * m * n * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\tassert False\n\t\treturn b,f\n\n\tdef bytes(self):\n\t\tb, f = self.bytesFlops()\n\t\treturn b\n\n\tdef flops(self):\n\t\tb, f = self.bytesFlops()\n\t\treturn f\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/loss.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\n#TODO: Add support for additional loss functions.\n\nclass MSELoss(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"mse_loss\")\n\t\tassert (len(args) == 3)\n\n\t\t#Get input, target and reduction\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tx = args[0]\n\t\telse:\n\t\t\tx = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\ty = args[1]\n\t\telse:\n\t\t\ty = list(filter(lambda x : x['name'] == \"target\", args))[0]\n\n\t\tif (args[2]['name'] == \"\"):\n\t\t\tr = args[2]\n\t\telse:\n\t\t\tr = list(filter(lambda x : x['name'] == \"reduction\", args))[0]\n\n\t\tassert (x['type'] == y['type'] == \"tensor\")\n\t\tassert (x['shape'] == y['shape'])\n\t\tassert (x['dtype'] == y['dtype'])\n\t\tassert (r['type'] == \"str\")\n\t\tassert (r['value'] in [\"none\", \"mean\", \"sum\"])\n\n\t\tself.shape = x['shape']\n\t\tself.type = x['dtype']\n\t\tself.red = r['value']\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type), ('red', self.red)])\n\t\treturn p\n\n\tdef elems(self):\n\t\tred = self.red\n\t\te = Utility.numElems(self.shape)\n\n\t\tif self.dir == \"fprop\":\n\t\t\tif red == \"none\":\n\t\t\t\te *= 3\n\t\t\telse:\n\t\t\t\te *= 2\n\t\telse:\n\t\t\tif red == \"none\":\n\t\t\t\te *= 4\n\t\t\telse:\n\t\t\t\te *= 3\n\t\treturn e\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\treturn self.elems() * 2 + 1\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/misc.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Foo(OperatorLayerBase):\n\t\"\"\"\n\tAn object of Foo is instantiated when we detect an unsupported operator.\n\t\"\"\"\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tshapes = []\n\t\ttypes = []\n\n\t\tfor arg in args:\n\t\t\tif arg['type'] == \"tensor\":\n\t\t\t\tshapes.append(arg['shape'])\n\t\t\t\ttypes.append(arg['dtype'])\n\n\t\tself.shape = shapes\n\t\tself.type = types\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn 0\n\nclass Copy(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"copy_\")\n\t\tassert (len(args) == 2)\n\n\t\tdst, src = args\n\t\tassert (src['type'] == dst['type'])\n\t\tassert (src['shape'] == dst['shape'])\n\n\t\tself.shape = src['shape']\n\t\tself.stype = src['dtype']\n\t\tself.dtype = dst['dtype']\n\n\tdef params(self):\n\t\t#The data type might be different\n\t\tp = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn self.elems() * 
(Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype))\n\nclass Clone(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"clone\")\n\t\tassert (len(args) == 1)\n\t\tt = args[0]\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn 2 * self.elems() * Utility.typeToBytes(self.type)\n\nclass Contiguous(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"contiguous\")\n\t\tassert (len(args) == 1)\n\t\tt = args[0]\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Any(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == 
\"Tensor\")\n\t\tassert (op == \"any\")\n\t\tassert (len(args) in [1, 2])\t#the second argument, if present, is a bool\n\t\tt = args[0]\n\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\t\tself.sub = d.sub\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/normalization.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass BatchNorm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (op == \"batch_norm\")\n\t\tassert (len(args) == 8)\n\t\ti = args[0]\n\t\tassert (i['type'] == \"tensor\")\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Variance algo-dependent, but this is a reasonable value.\n\t\treturn self.elems() * 8\n\n\tdef bytes(self):\n\t\te = self.elems()\n\t\tif self.dir == \"fprop\":\n\t\t\te *= 4\n\t\telse:\n\t\t\te *= 5\n\n\t\treturn e * Utility.typeToBytes(self.type)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/optim.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\n#TODO: Add support for other optimizers.\n\nclass Adam(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert(op == \"adam\")\n\t\tassert (len(args) == 12) or (len(args) == 14)\n\t\tw, hw, m, v, g = args[0:5]\n\t\tassert (w['shape'] == m['shape'] == v['shape'] == g['shape'])\n\t\tassert (hw['shape'] == w['shape']) or (hw['shape'] == (0,))\t\t#hw could be null\n\t\tassert (w['type'] == m['type'] == v['type'] == g['type'] == hw['type'] == \"tensor\")\n\t\tassert (w['dtype'] == m['dtype'] == v['dtype'] == \"float32\")\n\n\t\tself.w = w\n\t\tself.g = g\n\n\tdef params(self):\n\t\tp = OrderedDict([('T',self.w['shape']), ('wtype',self.w['dtype']), ('gtype',self.g['dtype'])])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\twshape = self.w['shape']\n\t\twtype = self.w['dtype']\n\t\tgtype = self.g['dtype']\n\t\tb = 0\n\n\t\telems = Utility.numElems(wshape)\n\n\t\t#Bytes to stream read/write w, m, v\n\t\tb += 6 * elems * Utility.typeToBytes(wtype)\n\n\t\t#Bytes to read \"g\"\n\t\tb += elems * Utility.typeToBytes(gtype)\n\n\t\tif wtype != gtype: #mixed precision\n\t\t\t#Bytes to write \"hw\"\n\t\t\tb += elems * Utility.typeToBytes(gtype)\n\n\t\treturn b\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/output.py",
    "content": "import errno, os, sys\n\nclass Output():\n\t\"\"\"\n\tThis class handles printing of a columned output and a CSV.\n\t\"\"\"\n\n\t# The table below is organized as \n\t# user_option: [output_header, attribute_in_Data_class, type, min_width_in_columned_output]\n\ttable = {\n\t\t\"idx\":\t\t[\"Idx\",\t\t\t\"index\",\tint,\t7],\n\t\t\"seq\":\t\t[\"SeqId\",\t\t\"seqId\",\tstr,\t7],\n\t\t\"altseq\":\t[\"AltSeqId\",\t\"altSeqId\",\tstr,\t7],\n\t\t\"tid\":\t\t[\"TId\",\t\t\t\"tid\",\t\tint,\t12],\n\t\t\"layer\":\t[\"Layer\", \t\t\"layer\",\tstr,\t10],\n\t\t\"trace\":\t[\"Trace\",\t\t\"trace\",\tstr,\t25],\n\t\t\"dir\":\t\t[\"Direction\",\t\"dir\",\t\tstr,\t5],\n\t\t\"sub\":\t\t[\"Sub\",\t\t\t\"sub\",\t\tint,\t3],\n\t\t\"mod\":\t\t[\"Module\",\t\t\"mod\",\t\tstr,\t15],\n\t\t\"op\":\t\t[\"Op\",\t\t\t\"op\",\t\tstr,\t15],\n\t\t\"kernel\":\t[\"Kernel\",\t\t\"name\",\t\tstr,\t0],\n\t\t\"params\":\t[\"Params\",\t\t\"params\",\tstr,\t0],\n\t\t\"sil\":\t\t[\"Sil(ns)\",\t\t\"sil\",\t\tint,\t10],\n\t\t\"tc\":\t\t[\"TC\",\t\t\t\"tc\",\t\tstr,\t2],\n\t\t\"device\":\t[\"Device\",\t\t\"device\",\tint,\t3],\n\t\t\"stream\":\t[\"Stream\",\t\t\"stream\",\tint,\t3],\n\t\t\"grid\":\t\t[\"Grid\",\t\t\"grid\",\t\tstr,\t12],\n\t\t\"block\":\t[\"Block\",\t\t\"block\",\tstr,\t12],\n\t\t\"flops\":\t[\"FLOPs\", \t\t\"flops\",\tint,\t12],\n\t\t\"bytes\":\t[\"Bytes\",\t\t\"bytes\", \tint,\t12]\n\t}\n\n\tdef __init__(self, args):\n\t\tself.cols = args.c\n\t\tself.csv = args.csv\n\t\tself.col = True if (args.w > 0) else False\n\t\tself.width = args.w\n\n\t\tw = 0\n\t\tfor col in self.cols:\n\t\t\tassert col in Output.table.keys()\n\t\t\tw += Output.table[col][3]\n\n\t\tif ((self.col) and (w > self.width)):\n\t\t\tprint(\"Minimum width required to print {} = {}. 
Exiting.\".format(\",\".join(self.cols), w))\n\t\t\tsys.exit(1)\n\n\t\tremainder = self.width - w\n\n\t\tif (\"kernel\" in self.cols) and (\"params\" in self.cols):\n\t\t\tOutput.table[\"kernel\"][3] = int(remainder/2)\n\t\t\tOutput.table[\"params\"][3] = int(remainder/2)\n\t\telif (\"kernel\" in self.cols):\n\t\t\tOutput.table[\"kernel\"][3] = remainder\n\t\telif (\"params\" in self.cols):\n\t\t\tOutput.table[\"params\"][3] = remainder\n\n\t\t#header format\n\t\tcadena = \"\"\n\t\tfor col in self.cols:\n\t\t\t_,_,t,w = Output.table[col]\n\t\t\tcadena += \"%-{}.{}s \".format(w,w)\n\n\t\tself.hFormat = cadena\n\n\t\t#data format\n\t\tcadena = \"\"\n\t\tfor col in self.cols:\n\t\t\t_,_,t,w = Output.table[col]\n\t\t\tif (t == str):\n\t\t\t\tcadena += \"%-{}.{}s \".format(w,w)\n\t\t\telif (t == int):\n\t\t\t\tcadena += \"%{}d \".format(w)\n\n\t\tself.dFormat = cadena\n\n\tdef foo(self, cadena, pformat):\n\t\tif self.csv:\n\t\t\tcadena = \",\".join(map(lambda x : '\"' + str(x) + '\"', cadena))\n\t\telif self.col:\n\t\t\tcadena = pformat % cadena\n\t\telse:\n\t\t\tcadena = \" \".join(map(str,cadena))\n\n\t\ttry:\n\t\t\tprint(cadena)\n\t\texcept IOError as e:\n\t\t\t#gracefully handle pipes\n\t\t\tif e.errno == errno.EPIPE:\n\t\t\t\t# Python flushes standard streams on exit; redirect remaining output\n\t\t\t\t# to devnull to avoid another BrokenPipeError at shutdown\n\n\t\t\t\tdevnull = os.open(os.devnull, os.O_WRONLY)\n\t\t\t\tos.dup2(devnull, sys.stdout.fileno())\n\t\t\t\tsys.exit(0)\n\t\t\telse:\n\t\t\t\tsys.exit(-1)\n\n\tdef header(self):\n\t\tcadena = ()\n\t\tfor col in self.cols:\n\t\t\th = Output.table[col][0]\n\t\t\tcadena = cadena + (h,)\n\n\t\tself.foo(cadena, self.hFormat)\n\n\tdef data(self, a):\n\t\tif a.dir == \"\":\n\t\t\tdirec = \"na\"\n\t\telse:\n\t\t\tdirec = a.dir\n\n\t\tif a.op == \"\":\n\t\t\top = \"na\"\n\t\telse:\n\t\t\top = a.op\n\n\t\tif a.mod == \"\":\n\t\t\tmod = \"na\"\n\t\telse:\n\t\t\tmod = a.mod\n\n\t\tcadena = ()\n\t\tfor col in 
self.cols:\n\t\t\tattr = Output.table[col][1]\n\t\t\tval = getattr(a, attr)\n\n\t\t\tif col == \"layer\":\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tval = \":\".join(val)\n\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tif col == \"trace\":\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tif self.col and len(val):\n\t\t\t\t\tval = val[-1]\n\t\t\t\t\tval = val.split(\"/\")[-1]\n\t\t\t\telse:\n\t\t\t\t\tval = \",\".join(val)\n\t\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tif col in [\"seq\", \"altseq\"]:\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tval = \",\".join(map(str,val))\n\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tcadena = cadena + (val,)\n\t\n\t\tself.foo(cadena, self.dFormat)\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/pointwise.py",
    "content": "import numpy as np\nfrom collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Pointwise(OperatorLayerBase):\n\n\tops = []\n\tops += [\"__abs__\", \"__neg__\", \"__invert__\"]\n\tops += [\"__add__\", \"__sub__\", \"__mul__\", \"__floordiv__\", \"__truediv__\", \"__pow__\", \"__mod__\"]\n\tops += [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rdiv__\", \"__rtruediv__\", \"__rfloordiv__\", \"__rpow__\"]\n\tops += [\"__iadd__\", \"__isub__\", \"__imul__\", \"__itruediv__\",]\n\tops += [\"__lt__\", \"__gt__\", \"__ge__\", \"__le__\", \"__eq__\", \"__ne__\",]\n\tops += [\"lt\", \"lt_\", \"gt\", \"gt_\", \"ge\", \"ge_\", \"le\", \"le_\", \"eq\", \"eq_\", \"ne\", \"ne_\",]\n\tops += [\"__and__\", \"__or__\", \"__xor__\", \"__lshift__\", \"__rshift__\"]\n\tops += [\"__iand__\", \"__ior__\", \"__ixor__\", \"__ilshift__\", \"__irshift__\"]\n\tops += [\"abs\", \"abs_\", \"neg\", \"neg_\"]\n\tops += [\"add\", \"add_\", \"div\", \"div_\", \"mul\", \"mul_\", \"reciprocal\", \"reciprocal_\", \"remainder\", \"remainder_\", \"sub\", \"sub_\",]\n\tops += [\"addcdiv\", \"addcdiv_\", \"addcmul\", \"addcmul_\"]\n\tops += [\"exp\", \"exp_\", \"expm1\", \"expm1_\", \"log\", \"log_\", \"log10\", \"log10_\", \"log1p\", \"log1p_\", \"log2\", \"log2_\", \"pow\", \"pow_\", \"rsqrt\", \"rsqrt_\", \"sqrt\", \"sqrt_\",]\n\tops += [\"ceil\", \"ceil_\", \"clamp\", \"clamp_\", \"floor\", \"floor_\", \"fmod\", \"fmod_\", \"frac\", \"frac_\", \"round\", \"round_\", \"sign\", \"sign_\", \"trunc\", \"trunc_\"]\n\tops += [\"acos\", \"acos_\", \"asin\", \"asin_\", \"atan\", \"atan_\", \"atan2\", \"atan2_\", \"cos\", \"cos_\", \"cosh\", \"cosh_\", \"sin\", \"sin_\", \"sinh\", \"sinh_\", \"tan\", \"tan_\", \"sigmoid\", \"sigmoid_\", \"tanh\", \"tanh_\"]\n\tops += [\"digamma\", \"erf\", \"erf_\", \"erfc\", \"erfc_\", \"erfinv\", \"erfinv_\", \"lerp\", \"lerp_\", \"mvlgamma\",]\n\n\t@staticmethod\n\tdef foo(d):\n\t\treturn 
d['name'],d['type'],d['shape'],d['dtype']\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.dir = d.dir\n\t\tassert (d.dir in [\"fprop\", \"bprop\"])\n\t\tassert (op in Pointwise.ops)\n\n\t\t#Filter out all named parameters (kwargs).\n\t\t#This might require revisiting in future.\n\t\targs = list(filter(lambda x : x['name'] == \"\", args))\n\n\t\t#Filter out non tensors\n\t\targs = list(filter(lambda x : x['type'] == \"tensor\", args))\n\n\t\tif (len(args) == 0):\n\t\t\tself.shape = [(1,)]\n\t\t\tself.type = \"float32\" #FIX\n\n\t\telif (len(args) == 1):\n\t\t\tin0 = args[0]\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\tassert (t0 == \"tensor\")\n\t\t\tself.shape = [s0,]\n\t\t\tself.type = dt0\n\n\t\telif (len(args) == 2):\n\t\t\tin0,in1 = args\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\t_,t1,s1,dt1 = Pointwise.foo(in1)\n\t\t\tassert (t0 == t1 == \"tensor\")\n\t\t\tassert (dt0 == dt1)\n\t\t\tself.shape = [s0,s1]\n\t\t\tself.type = dt0\n\n\t\telif (len(args) == 3):\n\t\t\tin0,in1,in2 = args\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\t_,t1,s1,dt1 = Pointwise.foo(in1)\n\t\t\t_,t2,s2,dt2 = Pointwise.foo(in2)\n\t\t\tassert (t0 == t1 == t2 == \"tensor\")\n\t\t\tassert (dt0 == dt1 == dt2)\n\t\t\tself.shape = [s0,s1,s2]\n\t\t\tself.type = dt0\n\t\telse:\n\t\t\tassert False\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T',self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\ttensor = self.shape\n\t\tt = self.type\n\n\t\tif (len(tensor) == 1):\n\t\t\telems = 2 * Utility.numElems(tensor[0])\n\t\telif (len(tensor) == 2):\n\t\t\tif (tensor[0] == tensor[1]):\t# same shape\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\tif 
self.dir == \"fprop\":\n\t\t\t\t\telems *= 3\n\t\t\t\telse:\n\t\t\t\t\tif (self.op_ in [\"add\", \"__add__\", \"sub\", \"__sub__\", \"__isub__\"]):\n\t\t\t\t\t\telems *= 2\n\t\t\t\t\telif (self.op_ in [\"__mul__\", \"__rmul__\", \"div\", \"__truediv__\"]):\n\t\t\t\t\t\telems *= 3\n\t\t\t\t\telse:\n\t\t\t\t\t\tassert False\n\t\t\telse:\t#check for broadcast conditions\n\t\t\t\tarray1 = np.empty(list(tensor[0]))\n\t\t\t\tarray2 = np.empty(list(tensor[1]))\n\t\t\t\ttry:\n\t\t\t\t\tout = np.broadcast(array1, array2).shape\n\t\t\t\texcept:\n\t\t\t\t\tassert False\n\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\telems += Utility.numElems(tensor[1])\n\t\t\t\telems += Utility.numElems(out)\n\t\t\t\t#TODO bprop\n\t\telif (len(tensor) == 3):\n\t\t\tif (tensor[0] == tensor[1] == tensor[2]):\t#same shape\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\telems *= 4\n\t\t\telse:\n\t\t\t\tassert False\n\t\telse:\n\t\t\tassert False\n\n\t\treturn elems\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\t# Note: some cases may still be missing.\n\n\t\tf = 0\n\t\tif self.op_ in [\"__abs__\", \"__neg__\", \"__add__\", \"__sub__\", \"__mul__\",\n\t\t\t\t\t\"__radd__\", \"__rmul__\", \"__iadd__\", \"__isub__\", \"__imul__\", \"__itruediv__\",\n\t\t\t\t\t\"abs\", \"abs_\", \"neg\", \"neg_\", \"add\", \"add_\", \"div\", \"div_\", \"mul\", \"mul_\",\n\t\t\t\t\t\"sub\", \"sub_\", \"exp\", \"exp_\", \"sign\", \"sign_\", \"trunc\", \"trunc_\",\n\t\t\t\t\t\"sin\", \"sin_\", \"cos\", \"cos_\", \"sinh\", \"sinh_\", \"cosh\", \"cosh_\",\n\t\t\t\t\t\"sqrt\", \"sqrt_\", \"rsqrt\", \"rsqrt_\", \"__lt__\", \"__gt__\", \"__ge__\", \"__le__\",\n\t\t\t\t\t\"__eq__\", \"__ne__\", \"lt\", \"lt_\", \"gt\", \"gt_\", \"ge\", \"ge_\", \"le\", \"le_\",\n\t\t\t\t\t\"eq\", \"eq_\", \"ne\", \"ne_\", \"ceil\", \"ceil_\", \"clamp\", \"clamp_\", \"floor\", \"floor_\",\n\t\t\t\t\t\"round\", \"sign\", \"sign_\", \"trunc\", \"trunc_\"]:\n\t\t\t# We're 
counting only one operand, not two (2 operands, 1 op)\n\t\t\tf = self.elems() / 2\n\t\telif self.op_ in [\"fmod\", \"fmod_\"]:\n\t\t\tf = self.elems()\n\t\telif self.op_ in [\"tanh\", \"tanh_\", \"sigmoid\", \"sigmoid_\", \"log\", \"log_\", \"log2\",\n\t\t\t \"log2_\", \"log10\", \"log10_\"]:\n\t\t\tf = self.elems() * 2\n\t\telif self.op_ in [\"asin\", \"asin_\", \"acos\", \"acos_\", \"atan\", \"atan_\"]:\n\t\t\t# no intrinsic, hence slow execution\n\t\t\t# surprisingly, asin/acos and atan were all the same (via nvprof measurement)\n\t\t\tf = self.elems() * 10\n\n\t\treturn f\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/pooling.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\n\n# Work in progress.\n\n#poolFuncs = [\"max_pool2d_with_indices_forward\", \"max_pool2d_with_indices\"]\nclass MaxPool2d(object):\n\n\t@staticmethod\n\tdef parse(marker):\n\n\t\tdef convert2Tuple(arg):\n\t\t\tassert (arg['type'] in [\"int\", \"tuple\"])\n\t\t\tif arg['type'] == \"int\":\n\t\t\t\treturn (arg['value'], arg['value'])\n\t\t\telse:\n\t\t\t\treturn arg['value']\n\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"max_pool2d\")\n\t\tassert (len(args) >= 2)\n\n\t\t#input\n\t\tassert (args[0]['name'] == \"\")\n\t\tinp = args[0]\n\t\tassert (inp['type'] == \"tensor\")\n\t\ti = inp['shape']\n\t\tt = inp['dtype']\n\t\tassert (len(i) == 4) #nchw tensor\n\n\t\t#kernel\n\t\tif (args[1]['name'] == \"\"):\n\t\t\tk = args[1]\n\t\telse:\n\t\t\tk = list(filter(lambda x : x['name'] == \"kernel_size\", args))[0]\n\t\tk = convert2Tuple(k)\n\n\t\t#stride\n\t\ts = k #default value\n\t\tif ((len(args) >= 3) and (args[2]['name'] == \"\")):\n\t\t\ts = args[2]\n\t\t\ts = convert2Tuple(s)\n\t\telif any(x['name'] == \"stride\" for x in args):\n\t\t\ts = list(filter(lambda x : x['name'] == \"stride\", args))[0]\n\t\t\ts = convert2Tuple(s)\n\n\t\t#padding\n\t\tp = (0,0)\n\t\tif ((len(args) >= 4) and (args[3]['name'] == \"\")):\n\t\t\tp = args[3]\n\t\t\tp = convert2Tuple(p)\n\t\telif any(x['name'] == \"padding\" for x in args):\n\t\t\tp = list(filter(lambda x : x['name'] == \"padding\", args))[0]\n\t\t\tp = convert2Tuple(p)\n\n\t\tparams = OrderedDict([('T', i), ('K', k), ('s',s), ('p',p), ('type', t)])\n\t\treturn params\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/prof.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis script reads the output (Python dictionary) created by parse.py.\nFor every kernel (line) in the input it determines\n\tmodule / class name e.g. torch.nn.functional\n\toperator name e.g. linear\n\tkernel parameters e.g. GEMM M, N, K, datatype\n\tbytes\n\tflops\n\ttensor core usage\n\tdirection (fprop, bprop)\n\tand other things. Please see the tool usage.\n\"\"\"\n\nfrom .usage import parseArgs\nfrom .output import Output\nfrom .utility import Utility\nfrom .pointwise import Pointwise\nfrom .convert import Convert\nfrom .blas import *\nfrom .embedding import Embedding\nfrom .reduction import *\nfrom .dropout import Dropout\nfrom .softmax import *\n#from pooling import * # work in progress\nfrom .linear import Linear\nfrom .optim import Adam\nfrom .misc import *\nfrom .conv import Conv\nfrom .activation import Activation\nfrom .index_slice_join_mutate import Cat, Reshape, MaskedScatter, Gather, Nonzero, IndexSelect, MaskedSelect\nfrom .recurrentCell import RNNCell\nfrom .normalization import BatchNorm\nfrom .randomSample import RandPerm\nfrom .loss import MSELoss\nfrom .data import Data\n\ndef findFpropKernel(seq):\n\t#Find the last fprop kernel with the same seqId\n\t#First look at seqId and then at altSeqId\n\tfor idx in reversed(range(len(kernels))):\n\t\tk = kernels[idx]\n\t\tif (seq in k['seqId']) and (k['dir'] == \"fprop\"):\n\t\t\treturn idx\n\n\tfor idx in reversed(range(len(kernels))):\n\t\tk = kernels[idx]\n\t\tif (seq in k['altSeqId']) and (k['dir'] == \"fprop\"):\n\t\t\treturn idx\n\n\treturn -1\n\t#print(\"Error: seqId {} not found.\".format(seq), file=sys.stderr)\n\t#assert False\n\ndef foo(mod, op, d):\n\tif (op[0] == \"linear\"):\n\t\txx = Linear(d)\n\n\t# rnncell, lstmcell, grucell\n\telif (mod[0] in[\"LSTMCell\", \"GRUCell\"]) and (op[0] == \"forward\"):\n\t\txx = RNNCell(d)\n\n\telif op[0] in [\"conv1d\", \"conv2d\",]:\n\t\txx = Conv(d)\n\n\telif (op[0] in Pointwise.ops):\n\t\txx = 
Pointwise(d)\n\n\telif (op[0] in Convert.ops):\n\t\txx = Convert(d)\n\n\telif op[0] in [\"__matmul__\", \"matmul\"]:\n\t\txx = Matmul(d)\n\n\telif op[0] == \"embedding\":\n\t\txx = Embedding(d)\n\n\t#reduction\n\telif op[0] == \"sum\":\n\t\txx = Sum(d)\n\n\telif op[0] == \"mean\":\n\t\txx = Mean(d)\n\n\telif op[0] == \"norm\":\n\t\txx = Norm(d)\n\n\telif op[0] == \"dropout\":\n\t\txx = Dropout(d)\n\n\t#Index, Slice, Join, Mutate\n\telif (op[0] == \"cat\"):\n\t\txx = Cat(d)\n\n\telif (op[0] == \"reshape\"):\n\t\txx = Reshape(d)\n\n\telif (op[0] == \"masked_scatter_\"):\n\t\txx = MaskedScatter(d)\n\n\telif (op[0] == \"gather\"):\n\t\txx = Gather(d)\n\n\telif (op[0] == \"nonzero\"):\n\t\txx = Nonzero(d)\n\n\telif (op[0] == \"index_select\"):\n\t\txx = IndexSelect(d)\n\n\telif (op[0] == \"masked_select\"):\n\t\txx = MaskedSelect(d)\n\n\t#blas\n\telif op[0] in [\"addmm\", \"addmm_\"]:\n\t\txx = Addmm(d)\n\n\telif op[0] == \"mm\":\n\t\txx = Mm(d)\n\n\telif op[0] == \"bmm\":\n\t\txx = Bmm(d)\n\n\t#softmax\n\telif op[0] == \"softmax\":\n\t\txx = Softmax(d)\n\n\telif op[0] == \"log_softmax\":\n\t\txx = LogSoftmax(d)\n\n\t#loss\n\telif op[0] == \"mse_loss\":\n\t\txx = MSELoss(d)\n\n\t#optimizers\n\telif op[0] == \"adam\":\n\t\txx = Adam(d)\n\n\t#normalization\n\telif op[0] == \"batch_norm\":\n\t\txx = BatchNorm(d)\n\n\t#random\n\telif op[0] == \"randperm\":\n\t\txx = RandPerm(d)\n\n\t#misc\n\telif op[0] == \"copy_\":\n\t\txx = Copy(d)\n\n\telif op[0] == \"clone\":\n\t\txx = Clone(d)\n\n\telif op[0] == \"contiguous\":\n\t\txx = Contiguous(d)\n\n\telif op[0] == \"any\":\n\t\txx = Any(d)\n\n\telif (op[0] in Activation.ops):\n\t\txx = Activation(d)\n\n\telif op[0] == \"to\":\n\t\txx = Convert(d)\n\n\telse:\n\t\txx = Foo(d)\n\n\treturn xx\n\ndef main():\n\t#Read cmd line arguments\n\tcmdArgs = parseArgs()\n\n\toutput = Output(cmdArgs)\n\toutput.header()\n\n\tidx = -1\n\t#Read in all the kernel info\n\tfor line in cmdArgs.file:\n\t\tidx += 1\n\t\tkernel = 
eval(line)\n\t\tassert(kernel)\n\t\tkernels.append(kernel)\n\n\t\tk = kernel\n\t\td = Data(k)\n\n\t\tmod = k['mod']\n\t\top = k['op']\n\n\t\tflops = 0\n\t\tparams = {\"na\":\"na\"}\n\t\ttc = \"na\"\n\t\tbytes = 0\n\n\t\tif (d.dir == \"bprop\"):\n\t\t\td.seqMarker = k['seqMarker']\n\t\t\tseq = k['seqId']\n\t\t\tif len(seq) > 1:\n\t\t\t\tpass\n\t\t\tseq = k['seqId'][:1]\n\t\t\tassert (len(seq) == 1), seq\n\t\t\t#assert (seq[0] != 0)\n\t\t\tassert (len(d.seqMarker) > 0)\n\t\t\t#If there is no useful marker associated, use the\n\t\t\t#sequence number to find the kernel from fprop\n\t\t\tif len(d.argMarker) == 0:\n\t\t\t\tindex = findFpropKernel(seq[0])\n\t\t\t\tif index >= 0:\n\t\t\t\t\td.argMarker = kernels[index]['marker']\n\t\t\t\t\td.modMarker = kernels[index]['reprMarkers']\n\t\t\t\t\tmod = kernels[index]['mod']\n\t\t\t\t\top = kernels[index]['op']\n\n\t\t\t\t\td.layer = kernels[index]['layer']\n\t\t\t\t\td.trace = kernels[index]['trace']\n\n\t\t# Check if marker has our annotations\n\t\tif len(d.argMarker) and Utility.hasNVTX(d.argMarker[0]):\n\n\t\t\txx = foo(mod, op, d)\n\n\t\t\tbytes = xx.bytes()\n\t\t\tflops = xx.flops()\n\t\t\top = xx.op()\n\t\t\tparams = xx.params()\n\t\t\ttc = xx.tc()\n\n\t\tif type(op) is list:\n\t\t\tif len(op):\n\t\t\t\top = op[0]\n\t\t\telse:\n\t\t\t\top = \"\"\n\n\t\tif type(mod) is list:\n\t\t\tif len(mod):\n\t\t\t\tmod = mod[0]\n\t\t\telse:\n\t\t\t\tmod = \"\"\n\n\t\td.index = idx+1\n\n\t\t# The following 8 come from operator class functions.\n\t\td.setParams(params)\n\t\td.tc = tc\n\t\td.flops = flops\n\t\td.bytes = bytes\n\t\td.mod = mod\n\t\td.op = op\n\n\t\toutput.data(d)\n\nkernels = []\nif __name__ == '__main__':\n\tmain()\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/randomSample.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass RandPerm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\")\n\t\tassert (op == \"randperm\")\n\t\tassert (len(args) == 1)\n\t\tn = args[0]\n\t\tassert n['type'] == \"int\"\n\t\tself.n = n['value']\n\n\tdef params(self):\n\t\tp = OrderedDict([('N', self.n)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn self.n * Utility.typeToBytes(\"int64\")\n\n\tdef flops(self):\n\t\t# Depends on RNG but this is probably a reasonable assumption.\n\t\treturn self.n * 3\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/recurrentCell.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\ndef hasTileSize(name):\n\tif (\"sgemm\" in name) or (\"884gemm\" in name) or (\"hgemm\" in name):\n\t\treturn True\n\telse:\n\t\treturn False\n\ndef ctaTile(name):\n\tname = name.split(\"_\")\n\tname = list(filter(lambda x : \"x\" in x, name))\n\tname = list(filter(lambda x : \"slice\" not in x, name))\n\tassert(len(name) == 1)\n\tname = name[0].split(\"x\")\n\tassert(len(name) == 2)\n\tname = list(map(int, name))\n\treturn name[0], name[1]\n\nclass RNNCell(OperatorLayerBase):\n\t\"\"\"\n\tThis class supports RNNCell, LSTMCell and GRUCell.\n\t\"\"\"\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.name = d.name\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\t\tself.grid = d.grid\n\n\t\tassert (op == \"forward\")\n\t\tassert (mod in [\"LSTMCell\", \"GRUCell\", \"RNNCell\"])\n\t\tassert (len(args) in [2,3])\n\n\t\tx,h = args[0],args[1]\n\t\tb1,ii = x['shape']\n\t\tb2,hh = h['shape']\n\t\tassert b1 == b2\n\t\tassert x['dtype'] == h['dtype']\n\t\tt = x['dtype']\n\n\t\tself.cell = mod\n\t\tself.inp = ii\n\t\tself.hid = hh\n\t\tself.b = b1\n\t\tself.type = t\n\n\t\tself.multiple = 1\n\t\tif self.cell == \"LSTMCell\":\n\t\t\tself.multiple = 4\n\t\telif self.cell == \"GRUCell\":\n\t\t\tself.multiple = 3\n\n\t\tself.gemm = None\n\t\tself.m = None\n\t\tself.n = None\n\t\tself.k = None\n\t\tself.elems = 0\n\n\t\tself.bar()\n\t\t\n\tdef params(self):\n\t\tif self.gemm is None:\n\t\t\tp = OrderedDict([('cell', self.cell), ('X', self.inp), ('H', self.hid), ('B', self.b), ('type', self.type)])\n\t\telse:\n\t\t\tassert self.m is not None\n\t\t\tassert self.n is not None\n\t\t\tassert self.k is not None\n\t\t\tp = OrderedDict([('gemm', self.gemm), ('M', self.m), ('N', 
self.n), ('K', self.k), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\tif \"gemm\" in self.name:\n\t\t\treturn 1 if \"884gemm\" in self.name else 0\n\t\telse:\n\t\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\tif self.gemm is not None:\n\t\t\tm, n, k, t = self.m, self.n, self.k, self.type\n\t\t\tb = (m*k + k*n + m*n) * Utility.typeToBytes(t)\n\t\telif self.elems != 0:\n\t\t\tb = self.elems * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\tb = 0\n\t\treturn b\n\n\tdef flops(self):\n\t\tif self.gemm is not None:\n\t\t\tm, n, k = self.m, self.n, self.k\n\t\t\tf = 2*m*n*k\n\t\telif self.elems != 0:\n\t\t\tf = 0 #TODO\n\t\telse:\n\t\t\tf = 0\n\t\treturn f\n\n\tdef bar(self):\n\t\tcell = self.cell\n\t\tX = self.inp\n\t\tH = self.hid\n\t\tB = self.b\n\t\tt = self.type\n\t\tsubseqId = self.sub\n\t\tdirec = self.dir\n\t\tname = self.name\n\t\tgrid = self.grid\n\t\tmultiple = self.multiple\n\n\t\tif direc == \"fprop\":\n\t\t\tsubseqId = subseqId % 3\n\t\t\tif subseqId == 0: #layer gemm\n\t\t\t\tself.gemm = \"layer\"\n\t\t\t\tself.m = multiple*H\n\t\t\t\tself.n = B\n\t\t\t\tself.k = X\n\t\t\telif subseqId == 1: #recurrent gemm\n\t\t\t\tself.gemm = \"recur\"\n\t\t\t\tself.m = multiple*H\n\t\t\t\tself.n = B\n\t\t\t\tself.k = H\n\t\t\telse:\n\t\t\t\tlayerGemmElems = multiple*H*B\n\t\t\t\trecurGemmElems = multiple*H*B\n\t\t\t\tcElems = H*B\n\t\t\t\thElems = H*B\n\t\t\t\ttotElems = layerGemmElems + recurGemmElems + 2*cElems + hElems\n\t\t\t\tself.elems = totElems\n\n\t\telse:\n\t\t\tif (\"gemm\" in name) and hasTileSize(name):\t#gemm\n\t\t\t\t#Get cta tile size\n\t\t\t\ttileX, tileY = ctaTile(name)\n\t\t\t\t#Get grid dimensions\n\t\t\t\tgrid = grid.split(\",\")\n\t\t\t\tgridX,gridY,gridZ = map(lambda x : int(x), grid)\n\n\t\t\t\tgemmM = tileX * gridX\n\t\t\t\tgemmN = tileY * gridY\n\n\t\t\t\tif name[-3:] == \"_nn\": # dgrad\n\t\t\t\t\tif (gemmM == H):\t# recurrent 
dgrad\n\t\t\t\t\t\t#Ideally gemmN = B, but we have a limited set of tile sizes.\n\t\t\t\t\t\tgemmN = B\n\t\t\t\t\t\tgemmK = multiple*H\n\n\t\t\t\t\t\tself.gemm = \"recur\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telif (gemmM == X):\t# layer dgrad\n\t\t\t\t\t\t#assert(gemmN % B == 0)\n\t\t\t\t\t\tgemmK = multiple*H\n\n\t\t\t\t\t\tself.gemm = \"layer\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telse:\n\t\t\t\t\t\tpass\n\n\t\t\t\telif name[-3:] == \"_nt\": #wgrad\n\t\t\t\t\tif (gemmM == H):\t#recurrent wgrad\n\t\t\t\t\t\tassert (gemmN == multiple*H)\n\t\t\t\t\t\tgemmK = B\n\n\t\t\t\t\t\tself.gemm = \"recur\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telif (gemmM == X):\t#layer wgrad\n\t\t\t\t\t\tassert (gemmN == multiple*H)\n\t\t\t\t\t\tgemmK = B\n\n\t\t\t\t\t\tself.gemm = \"layer\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telse:\n\t\t\t\t\t\tpass\n\t\t\t\telse:\n\t\t\t\t\tpass\n\t\t\telse:\n\t\t\t\tpass\n\n\t\treturn\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/reduction.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Mean(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"mean\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\t\ti = args[0]\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\tif self.sub == 0:\n\t\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\treturn 0\n\n\tdef flops(self):\n\t\tif self.sub == 0:\n\t\t\treturn self.elems() + 1\n\t\telse:\n\t\t\treturn 0\n\nclass Sum(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"sum\")\n\t\tassert (len(args) >= 1)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\ti = args[0]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef 
mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: This is incorrect, need to calculate actual flops (say via nvprof)\n\t\treturn self.elems()\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\nclass Norm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"norm\")\n\t\t#assert (len(args) == 1)\n\t\ti = args[0]\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\t# square and add plus sqrt\n\t\treturn 2 * self.elems() + 1\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/softmax.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Softmax(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"softmax\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\t\tself.shape = args[0]['shape']\n\t\tself.type = args[0]['dtype']\n\t\tself.dir = d.dir\n\n\t\treturn\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: exp, sum-reduce, divide\n\t\t#flops = elems * 3\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * Utility.typeToBytes(self.type)\n\t\tb *= 3 if self.dir == \"fprop\" else 5 #verify\n\t\treturn b\n\nclass LogSoftmax(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"log_softmax\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\ti = args[0]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tt = i['dtype']\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\t\treturn\n\n\tdef op(self):\n\t\treturn 
self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: exp, sum-reduce, divide, log\n\t\t#flops = elems * 4\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * Utility.typeToBytes(self.type)\n\t\tb *= 3 if self.dir == \"fprop\" else 5 #verify\n\t\treturn b\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/usage.py",
    "content": "import sys\nimport argparse\n\ndef parseArgs():\n\t\"\"\"\n\tPrint usage and parse arguments.\n\t\"\"\"\n\n\tdef check_cols(value):\n\t\tvalid = [\"idx\", \"seq\", \"altseq\", \"tid\", \"layer\", \"trace\", \"dir\", \"sub\", \"mod\", \"op\", \"kernel\", \"params\", \"sil\", \"tc\", \"device\", \"stream\", \"grid\", \"block\", \"flops\", \"bytes\"]\n\t\tcols = value.split(\",\")\n\t\tfor col in cols:\n\t\t\tif col not in valid:\n\t\t\t\traise argparse.ArgumentTypeError(\"{} is not a valid column name. Valid column names are {}.\".format(col, \",\".join(valid)))\n\t\treturn cols\n\n\tdef openFile(f):\n\t\ttry:\n\t\t\td = open(f, \"r\")\n\t\t\treturn d\n\t\texcept IOError:\n\t\t\tprint(\"Error opening file {}. Exiting.\".format(f), file=sys.stderr)\n\t\t\tsys.exit(1)\n\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"PyTorch Profiler\", formatter_class=argparse.RawTextHelpFormatter)\n\tparser.add_argument(\"file\",\n\t\tnargs='?',\n\t\ttype=str,\n\t\tdefault=None,\n\t\thelp=\"Output of parse.py (Python dictionary).\")\n\n\tparser.add_argument(\"-c\",\n\t\ttype=check_cols,\n\t\tdefault=\"idx,dir,sub,mod,op,kernel,params,sil\",\n\t\thelp='''Comma seperated names of columns to print.\nidx:      Index\nseq:      PyTorch Sequence Id\naltseq:   PyTorch Alternate Sequence Id\ntid:      Thread Id\nlayer:    User annotated NVTX string (can be nested)\ntrace:    Function Call Trace\ndir:      Direction\nsub:      Sub Sequence Id\nmod:      Module\nop:       Operattion\nkernel:   Kernel Name\nparams:   Parameters\nsil:      Silicon Time (in ns)\ntc:       Tensor Core Usage\ndevice:   GPU Device Id\nstream:   Stream Id\ngrid:     Grid Dimensions\nblock:    Block Dimensions\nflops:    Floating point ops (FMA = 2 FLOPs)\nbytes:    Number of bytes in and out of DRAM\ne.g. 
-c idx,kernel,sil''')\n\n\tgroup = parser.add_mutually_exclusive_group()\n\tgroup.add_argument(\"--csv\",\n\t\taction=\"store_true\",\n\t\tdefault=False,\n\t\thelp=\"Print a CSV output.\")\n\tgroup.add_argument(\"-w\",\n\t\ttype=int,\n\t\tdefault=0,\n\t\thelp=\"Width of columnated output.\")\n\n\targs = parser.parse_args()\n\tif args.file is None:\n\t\targs.file = sys.stdin\n\telse:\n\t\targs.file = openFile(args.file)\n\treturn args\n"
  },
  {
    "path": "KoSentenceT5/apex/pyprof/prof/utility.py",
    "content": "from functools import reduce\n\nclass Utility(object):\n\n\t@staticmethod\n\tdef numElems(shape):\n\t\tassert (type(shape) == tuple)\n\t\treturn reduce(lambda x,y: x*y, shape, 1)\n\n\t@staticmethod\n\tdef typeToBytes(t):\n\t\tif (t in [\"uint8\", \"int8\", \"byte\", \"char\", \"bool\"]):\n\t\t\treturn 1\n\t\telif (t in [\"float16\", \"half\", \"int16\", \"short\"]):\n\t\t\treturn 2\n\t\telif (t in [\"float32\", \"float\", \"int32\", \"int\"]):\n\t\t\treturn 4\n\t\telif (t in [\"int64\", \"long\", \"float64\", \"double\"]):\n\t\t\treturn 8\n\t\tassert False\n\n\t@staticmethod\n\tdef typeToString(t):\n\t\tif (t in [\"uint8\", \"byte\", \"char\",]):\n\t\t\treturn \"uint8\"\n\t\telif (t in [\"int8\",]):\n\t\t\treturn \"int8\"\n\t\telif (t in [\"int16\", \"short\",]):\n\t\t\treturn \"int16\"\n\t\telif (t in [\"float16\", \"half\"]):\n\t\t\treturn \"fp16\"\n\t\telif (t in [\"float32\", \"float\"]):\n\t\t\treturn \"fp32\"\n\t\telif (t in [\"int32\", \"int\",]):\n\t\t\treturn \"int32\"\n\t\telif (t in [\"int64\", \"long\"]):\n\t\t\treturn \"int64\"\n\t\telif (t in [\"float64\", \"double\",]):\n\t\t\treturn \"fp64\"\n\t\telif (t in [\"bool\",]):\n\t\t\treturn \"bool\"\n\t\tassert False\n\n\t@staticmethod\n\tdef hasNVTX(marker):\n\t\tif type(marker) is str:\n\t\t\ttry:\n\t\t\t\tmarker = eval(marker)\n\t\t\texcept:\n\t\t\t\treturn False\n\n\t\tif type(marker) is dict:\n\t\t\tkeys  = marker.keys()\n\t\t\treturn (\"mod\" in keys) and (\"op\" in keys) and (\"args\" in keys)\n\t\telse:\n\t\t\treturn False\n\n\t@staticmethod\n\tdef isscalar(t):\n\t\treturn (t in [\"float\", \"int\"])\n"
  },
  {
    "path": "KoSentenceT5/apex/reparameterization/README.md",
    "content": "Under construction...\n"
  },
  {
    "path": "KoSentenceT5/apex/reparameterization/__init__.py",
    "content": "from .weight_norm import WeightNorm\nfrom .reparameterization import Reparameterization\n\ndef apply_weight_norm(module, name='', dim=0, hook_child=True):\n    r\"\"\"\n    Applies weight normalization to a parameter in the given module.\n    If no parameter is provided, applies weight normalization to all\n    parameters in model (except 1-d vectors and scalars).\n\n    .. math::\n         \\mathbf{w} = g \\dfrac{\\mathbf{v}}{\\|\\mathbf{v}\\|}\n\n    Weight normalization is a reparameterization that decouples the magnitude\n    of a weight tensor from its direction. This replaces the parameter specified\n    by `name` (e.g. \"weight\") with two parameters: one specifying the magnitude\n    (e.g. \"weight_g\") and one specifying the direction (e.g. \"weight_v\").\n    Weight normalization is implemented via a hook that recomputes the weight\n    tensor from the magnitude and direction before every :meth:`~Module.forward`\n    call.\n\n    By default, with `dim=0`, the norm is computed independently per output\n    channel/plane. To compute a norm over the entire weight tensor, use\n    `dim=None`.\n\n    See https://arxiv.org/abs/1602.07868\n\n    Args:\n        module (nn.Module): containing module\n        name (str, optional): name of weight parameter\n        dim (int, optional): dimension over which to compute the norm\n        hook_child (boolean, optional): adds reparameterization hook to direct parent of the \n            parameters. If False, it's added to `module` instead. 
Default: True\n\n    Returns:\n        The original module with the weight norm hook\n\n    Example::\n\n        >>> m = apply_weight_norm(nn.Linear(20, 40), name='weight')\n        Linear (20 -> 40)\n        >>> m.weight_g.size()\n        torch.Size([40, 1])\n        >>> m.weight_v.size()\n        torch.Size([40, 20])\n\n    \"\"\"\n    return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,\n                                    name=name, dim=dim)\n\ndef remove_weight_norm(module, name='', remove_all=False):\n    \"\"\"\n    Removes the weight normalization reparameterization of a parameter from a module.\n    If no parameter is supplied then all weight norm parameterizations are removed.\n    Args:\n        module (nn.Module): containing module\n        name (str, optional): name of weight parameter\n    Example:\n        >>> m = apply_weight_norm(nn.Linear(20, 40))\n        >>> remove_weight_norm(m)\n    \"\"\"\n    return remove_reparameterization(module, reparameterization=WeightNorm,\n                                    name=name, remove_all=remove_all)\n\ndef apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):\n    \"\"\"\n    Applies a given weight reparameterization (such as weight normalization) to\n    a parameter in the given module. If no parameter is given, applies the reparameterization\n    to all parameters in model (except 1-d vectors and scalars).\n\n    Args:\n        module (nn.Module): containing module\n        reparameterization (Reparameterization): reparameterization class to apply\n        name (str, optional): name of weight parameter\n        dim (int, optional): dimension over which to perform reparameterization op\n        hook_child (boolean, optional): adds reparameterization hook to direct parent of the \n            parameters. If False, it's added to `module` instead. 
Default: True\n\n    Returns:\n        The original module with the reparameterization hook\n\n    Example::\n\n        >>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)\n        Linear (20 -> 40)\n\n    \"\"\"\n    assert reparameterization is not None\n    if name != '':\n        Reparameterization.apply(module, name, dim, reparameterization, hook_child)\n    else:\n        names = list(module.state_dict().keys())\n        for name in names:\n            apply_reparameterization(module, reparameterization, name, dim, hook_child)\n    return module\n\ndef remove_reparameterization(module, reparameterization=Reparameterization,\n                                name='', remove_all=False):\n    \"\"\"\n    Removes the given reparameterization of a parameter from a module.\n    If no parameter is supplied then all reparameterizations are removed.\n    Args:\n        module (nn.Module): containing module\n        reparameterization (Reparameterization): reparameterization class to apply\n        name (str, optional): name of weight parameter\n        remove_all (bool, optional): if True, remove all reparameterizations of the given type. 
Default: False\n    Example:\n        >>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)\n        >>> remove_reparameterization(m)\n    \"\"\"\n    if name != '' or remove_all:\n        to_remove = []\n        for k, hook in module._forward_pre_hooks.items():\n            if isinstance(hook, reparameterization) and (hook.name == name or remove_all):\n                hook.remove(module)\n                to_remove.append(k)\n        if len(to_remove) > 0:\n            for k in to_remove:\n                del module._forward_pre_hooks[k]\n            return module\n        if not remove_all:\n            raise ValueError(\"reparameterization of '{}' not found in {}\"\n                             .format(name, module))\n    else:\n        modules = [module]+[x for x in module.modules()]\n        for m in modules:\n            remove_reparameterization(m, reparameterization=reparameterization, remove_all=True)\n        return module\n"
  },
  {
    "path": "KoSentenceT5/apex/reparameterization/reparameterization.py",
    "content": "import torch\nfrom torch.nn.parameter import Parameter\nimport sys\nclass Reparameterization(object):\n    \"\"\"\n    Class interface for performing weight reparameterizations\n    Arguments:\n        name (str): name of weight parameter\n        dim (int): dimension over which to compute the norm\n        module (nn.Module): parent module to which param `name` is registered to\n        retain_forward (bool, optional): if False deletes weight on call to \n            module.backward. Used to avoid memory leaks with DataParallel Default: True\n    Attributes:\n        reparameterization_names (list, str): contains names of all parameters \n            needed to compute reparameterization.\n        backward_hook_key (int): torch.utils.hooks.RemovableHandle.id for hook used in module backward pass.\n    \"\"\"\n\n    def __init__(self, name, dim, module, retain_forward=True):\n        self.name = name\n        self.dim = dim\n        self.evaluated = False\n        self.retain_forward = retain_forward\n        self.reparameterization_names = []\n        self.backward_hook_key = None\n        self.module = module\n\n    def compute_weight(self, module=None, name=None):\n        \"\"\"\n        Computes reparameterized weight value to assign value to module attribute\n        with name `name`.\n        See WeightNorm class for example.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n        Returns:\n            w (Tensor): Tensor object containing value of reparameterized weight\n        \"\"\"\n        raise NotImplementedError\n\n    def reparameterize(self, name, weight, dim):\n        \"\"\"\n        Creates Parameters to be used for reparameterization and creates names that\n        for attributes for the module these Parameters will correspond to.\n        The parameters will be registered according to the names provided.\n        See WeightNorm class for example.\n        Arguments:\n          
  module (nn.Module): module with weight we'd like to reparameterize\n            name (str, optional): name of weight parameter\n            dim (int, optional): dimension over which to compute parameterization\n        Returns:\n            names (list, str): names of Parameters to be used for reparameterization\n            params (list, Parameter): Parameters to be used for reparameterization\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def apply(module, name, dim, reparameterization=None, hook_child=True):\n        \"\"\"\n        Applies reparametrization to module's `name` parameter and modifies instance attributes as appropriate.\n        `hook_child` adds reparameterization hook to direct parent of the parameters. If False, it's added to `module` instead.\n        \"\"\"\n        if reparameterization is None:\n            reparameterization = Reparameterization\n        module2use, name2use = Reparameterization.get_module_and_name(module, name)\n        # does not work on sparse\n        if name2use is None or isinstance(module2use, (torch.nn.Embedding, torch.nn.EmbeddingBag)):\n            return\n\n        if hook_child:\n            fn = reparameterization(name2use, dim, module2use)\n        else:\n            fn = reparameterization(name, dim, module)\n\n        weight = getattr(module2use, name2use)\n        if weight.dim() <= 1:\n            return\n\n        # remove weight from parameter list\n        del module2use._parameters[name2use]\n\n        # add parameters of reparameterization of parameter to module\n        names, params = fn.reparameterize(name2use, weight, dim)\n        for n, p in zip(names, params):\n            module2use.register_parameter(n, p)\n\n        # add parameters to reparameterization so they can be removed later\n        fn.reparameterization_names = names\n\n        setattr(module2use, name2use, None)\n\n        hook_module = module2use\n        if not hook_child:\n            
hook_module = module\n        # recompute weight before every forward()\n        hook_module.register_forward_pre_hook(fn)\n\n        # remove weight during backward\n        handle = hook_module.register_backward_hook(fn.backward_hook)\n        # get hook key so we can delete it later\n        fn.backward_hook_key = handle.id\n\n        return fn\n\n    @staticmethod\n    def get_module_and_name(module, name):\n        \"\"\"\n        recursively fetches (possible) child module and name of weight to be reparameterized\n        \"\"\"\n        name2use = None\n        module2use = None\n        names = name.split('.')\n        if len(names) == 1 and names[0] != '':\n            name2use = names[0]\n            module2use = module\n        elif len(names) > 1:\n            module2use = module\n            name2use = names[0]\n            for i in range(len(names)-1):\n                module2use = getattr(module2use, name2use)\n                name2use = names[i+1]\n        return module2use, name2use\n\n    def get_params(self, module):\n        \"\"\"gets params of reparameterization based on known attribute names\"\"\"\n        return [getattr(module, n) for n in self.reparameterization_names]\n\n    def remove(self, module):\n        \"\"\"removes reparameterization and backward hook (does not remove forward hook)\"\"\"\n        module2use, name2use = Reparameterization.get_module_and_name(module, self.name)\n        for p in self.get_params(module2use):\n            p.requires_grad = False\n        weight = self.compute_weight(module2use, name2use)\n        delattr(module2use, name2use)\n        for n in self.reparameterization_names:\n            del module2use._parameters[n]\n        module2use.register_parameter(name2use, Parameter(weight.data))\n        del module._backward_hooks[self.backward_hook_key]\n\n    def __call__(self, module, inputs):\n        \"\"\"callable hook for forward pass\"\"\"\n        module2use, name2use = 
Reparameterization.get_module_and_name(module, self.name)\n        _w = getattr(module2use, name2use)\n        if not self.evaluated or _w is None:\n            setattr(module2use, name2use, self.compute_weight(module2use, name2use))\n            self.evaluated = True\n\n    def backward_hook(self, module, grad_input, grad_output):\n        \"\"\"callable hook for backward pass\"\"\"\n        module2use, name2use = Reparameterization.get_module_and_name(module, self.name)\n        wn = getattr(module2use, name2use)\n        self.evaluated = False\n"
  },
  {
    "path": "KoSentenceT5/apex/reparameterization/weight_norm.py",
    "content": "import torch\nfrom torch.nn.parameter import Parameter\nfrom ..fp16_utils import Fused_Weight_Norm\nimport time\n\nfrom .reparameterization import Reparameterization\n\ndef _norm(p, dim):\n    \"\"\"Computes the norm over all dimensions except dim\"\"\"\n    if dim is None:\n        return p.norm()\n    elif dim == 0:\n        output_size = (p.size(0),) + (1,) * (p.dim() - 1)\n        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)\n    elif dim == p.dim() - 1:\n        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)\n        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)\n    return _norm(p.transpose(0, dim), 0).transpose(0, dim)\n\nHALF_TYPES = (torch.cuda.HalfTensor, torch.HalfTensor)\n\nclass WeightNorm(Reparameterization):\n    r\"\"\"\n    Weight normalization is a reparameterization that decouples the magnitude\n    of a weight tensor from its direction. This replaces the parameter specified\n    by `name` (e.g. \"weight\") with two parameters: one specifying the magnitude\n    (e.g. \"weight_g\") and one specifying the direction (e.g. \"weight_v\").\n    Weight normalization is implemented via a hook that recomputes the weight\n    tensor from the magnitude and direction before every :meth:`~Module.forward`\n    call.\n\n    .. math::\n         \\mathbf{w} = g \\dfrac{\\mathbf{v}}{\\|\\mathbf{v}\\|}\n\n    By default, with `dim=0`, the norm is computed independently per output\n    channel/plane. 
To compute a norm over the entire weight tensor, use\n    `dim=None`.\n    \"\"\"\n    def compute_weight(self, module=None, name=None):\n        \"\"\"\n        Computes the weight-normalized weight value to assign to the module attribute\n        with name `name`.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n        Returns:\n            w (Tensor): Tensor object containing value of reparameterized weight\n        \"\"\"\n        if module is None:\n            module = self.module\n        if name is None:\n            name = self.name\n        module, name = Reparameterization.get_module_and_name(module, name)\n        g = getattr(module, name + '_g')\n        v = getattr(module, name + '_v')\n\n        fused_weight_norm = Fused_Weight_Norm.apply\n        v = v.contiguous()\n        w = fused_weight_norm(v, g, self.dim)\n\n        return w\n\n    def reparameterize(self, name, weight, dim):\n        \"\"\"\n        Creates Parameters v and g to be used for weight normalization\n        and creates names for the module attributes these Parameters\n        will correspond to. The parameters will be registered according to the names\n        provided.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n            name (str, optional): name of weight parameter\n            dim (int, optional): dimension over which to compute parameterization\n        Returns:\n            names (list, str): names of Parameters to be used for reparameterization\n            params (list, Parameter): Parameters to be used for reparameterization\n        \"\"\"\n        names = [name + '_g', name + '_v']\n        params = [Parameter(_norm(weight, dim).data), Parameter(weight.data)]\n        return names, params\n"
  },
  {
    "path": "KoSentenceT5/data/dataloader.py",
    "content": "import numpy\nimport torch\nimport logging\nfrom torch.utils.data import DataLoader, Dataset\nfrom transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast\n\nlogger = logging.getLogger(__name__)\n\n\nclass ModelDataLoader(Dataset):\n    def __init__(self, file_path, args, metric, tokenizer, type_):\n        self.type = type_\n        self.args = args\n        self.metric = metric\n\n        \"\"\"NLI\"\"\"\n        self.anchor = []\n        self.anchor_dec = []\n\n        self.positive = []\n        self.positive_dec = []\n\n        self.negative = []\n        self.negative_dec = []\n\n        \"\"\"STS\"\"\"\n        self.label = []\n        \n        self.sentence_1 = []\n        self.sentence_1_dec = []\n\n        self.sentence_2 = []\n        self.sentence_2_dec = []\n\n        #  -------------------------------------\n        self.bert_tokenizer = tokenizer\n        self.file_path = file_path\n\n        special_tokens = {'bos_token': \"[CLS]\"}\n        self.bert_tokenizer.add_special_tokens(special_tokens)\n\n        self.init_token = self.bert_tokenizer.bos_token\n        self.pad_token = self.bert_tokenizer.pad_token\n        self.unk_token = self.bert_tokenizer.unk_token\n        self.eos_token = self.bert_tokenizer.eos_token\n\n        self.init_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.init_token)\n        self.pad_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.pad_token)\n        self.unk_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.unk_token)\n        self.eos_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.eos_token)\n        \n        print(self.init_token, self.init_token_idx)\n        print(self.pad_token, self.pad_token_idx)\n        print(self.unk_token, self.unk_token_idx)\n        print(self.eos_token, self.eos_token_idx)\n        \n    def load_data(self, type):\n\n        with open(self.file_path) as file:\n            lines = file.readlines()\n\n            
for line in lines:\n                _ = self.data2tensor(line, type)\n                \n        if type == 'train':\n            assert len(self.anchor) == len(self.positive) == len(self.negative)\n        else:\n            assert len(self.sentence_1) == len(self.sentence_2) == len(self.label)\n\n    def data2tensor(self, line, type):\n        split_data = line.split('\\t')\n\n        if type == 'train':\n            anchor_sen, positive_sen, negative_sen = split_data\n            \n            anchor = self.bert_tokenizer(anchor_sen, \n                                         truncation=True,\n                                         return_tensors=\"pt\",\n                                         max_length=self.args.max_len,\n                                         padding='max_length')\n            \n            anh_dec_ids = torch.cat([torch.tensor([self.init_token_idx]).unsqueeze(0), anchor['input_ids'][:, :-1]], dim=-1)    \n            anchor['dec_ids'] = anh_dec_ids\n        \n            positive = self.bert_tokenizer(positive_sen, \n                                           truncation=True,\n                                           return_tensors=\"pt\",\n                                           max_length=self.args.max_len,\n                                           padding='max_length')\n\n            pos_dec_ids = torch.cat([torch.tensor([self.init_token_idx]).unsqueeze(0), positive['input_ids'][:, :-1]], dim=-1)\n            positive['dec_ids'] = pos_dec_ids\n            \n            negative = self.bert_tokenizer(negative_sen, \n                                           truncation=True,\n                                           return_tensors=\"pt\",\n                                           max_length=self.args.max_len,\n                                           padding='max_length')\n        \n            neg_dec_ids = torch.cat([torch.tensor([self.init_token_idx]).unsqueeze(0), negative['input_ids'][:, :-1]], dim=-1)\n            
negative['dec_ids'] = neg_dec_ids\n            \n            self.anchor.append(anchor)\n            self.positive.append(positive)\n            self.negative.append(negative)\n        \n        else:\n            sentence_1, sentence_2, label = split_data\n    \n            sentence_1 = self.bert_tokenizer(sentence_1, \n                                             truncation=True,\n                                             return_tensors=\"pt\",\n                                             max_length=self.args.max_len,\n                                             padding='max_length')\n            \n            s1_dec_ids = torch.cat([torch.tensor([self.init_token_idx]).unsqueeze(0), sentence_1['input_ids'][:, :-1]], dim=-1)    \n            \n            sentence_1['dec_ids'] = s1_dec_ids\n            \n            sentence_2 = self.bert_tokenizer(sentence_2,\n                                             truncation=True,\n                                             return_tensors=\"pt\",\n                                             max_length=self.args.max_len,\n                                             padding='max_length')\n            s2_dec_ids = torch.cat([torch.tensor([self.init_token_idx]).unsqueeze(0), sentence_2['input_ids'][:, :-1]], dim=-1)\n            \n            sentence_2['dec_ids'] = s2_dec_ids\n            \n            self.sentence_1.append(sentence_1)\n            self.sentence_2.append(sentence_2)\n            self.label.append(float(label.strip())/5.0)\n\n    def __getitem__(self, index):\n\n        if self.type == 'train':\n            inputs = {'anchor': {\n                'source': torch.LongTensor(self.anchor[index]['input_ids']),\n                'attention_mask': self.anchor[index]['attention_mask'],\n                'dec_ids': torch.LongTensor(self.anchor[index]['dec_ids'])\n                                },\n                      'positive': {\n                'source': 
torch.LongTensor(self.positive[index]['input_ids']),\n                'attention_mask': self.positive[index]['attention_mask'],\n                'dec_ids': torch.LongTensor(self.positive[index]['dec_ids'])\n                                },\n                      'negative': {\n                'source': torch.LongTensor(self.negative[index]['input_ids']),\n                'attention_mask': self.negative[index]['attention_mask'],\n                'dec_ids': torch.LongTensor(self.negative[index]['dec_ids'])\n                                }}\n        else:\n\n            inputs = {'sentence_1': {\n                'source': torch.LongTensor(self.sentence_1[index]['input_ids']),\n                'attention_mask': self.sentence_1[index]['attention_mask'],\n                'dec_ids': torch.LongTensor(self.sentence_1[index]['dec_ids'])\n                                },\n                      'sentence_2': {\n                'source': torch.LongTensor(self.sentence_2[index]['input_ids']),\n                'attention_mask': self.sentence_2[index]['attention_mask'],\n                'dec_ids': torch.LongTensor(self.sentence_2[index]['dec_ids'])\n                                },\n                      'label': {\n                          'value': torch.FloatTensor([self.label[index]])}\n                }\n\n        for key, value in inputs.items():\n            for inner_key, inner_value in value.items():\n                inputs[key][inner_key] = inner_value.squeeze(0)\n                \n        inputs = self.metric.move2device(inputs, self.args.device)\n        \n        return inputs\n\n    def __len__(self):\n        if self.type == 'train':\n            return len(self.anchor)\n        else:\n            return len(self.label)\n\n\n# Get train, valid, test data loader and BERT tokenizer\ndef get_loader(args, metric):\n    \n    tokenizer = AutoTokenizer.from_pretrained(args.model)\n\n    path_to_train_data = args.path_to_data + '/' + args.train_data\n    
path_to_valid_data = args.path_to_data + '/' + args.valid_data\n    path_to_test_data = args.path_to_data + '/' + args.test_data\n\n    if args.train == 'True' and args.test == 'False':\n        train_iter = ModelDataLoader(path_to_train_data, args, metric, tokenizer, type_='train')\n        valid_iter = ModelDataLoader(path_to_valid_data, args, metric, tokenizer, type_='valid')\n\n        train_iter.load_data('train')\n        valid_iter.load_data('valid')\n\n        loader = {'train': DataLoader(dataset=train_iter,\n                                      batch_size=args.batch_size,\n                                      shuffle=True),\n                  'valid': DataLoader(dataset=valid_iter,\n                                      batch_size=args.batch_size,\n                                      shuffle=True)}\n\n    elif args.train == 'False' and args.test == 'True':\n        test_iter = ModelDataLoader(path_to_test_data, args, metric, tokenizer, type_='test')\n        test_iter.load_data('test')\n\n        loader = {'test': DataLoader(dataset=test_iter,\n                                     batch_size=args.batch_size,\n                                     shuffle=True)}\n\n    else:\n        loader = None\n\n    return loader, tokenizer\n\n\nif __name__ == '__main__':\n    get_loader('test')\n"
  },
  {
    "path": "KoSentenceT5/main.py",
    "content": "from model.setting import Setting, Arguments\nfrom model.simcse.processor import Processor\n\n\ndef main(args, logger) -> None:\n    processor = Processor(args)\n    config = processor.model_setting()\n    logger.info('Model Setting Complete')\n\n    if args.train == 'True':\n        logger.info('Start Training')\n\n        for epoch in range(args.epochs):\n\n            processor.train(epoch+1)\n\n    if args.test == 'True':\n        logger.info(\"Start Test\")\n\n        processor.test()\n\n        processor.metric.print_size_of_model(config['model'])\n        processor.metric.count_parameters(config['model'])\n\n\nif __name__ == '__main__':\n    args, logger = Setting().run()\n    main(args, logger)\n"
  },
  {
    "path": "KoSentenceT5/model/loss.py",
    "content": "import torch\nimport logging\nimport numpy as np\nimport torch.nn as nn\nfrom model.utils import Metric\nfrom scipy.stats import pearsonr, spearmanr\nfrom sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances\n\nlogger = logging.getLogger(__name__)\n\n\nclass Loss():\n\n    def __init__(self, args):\n        self.args = args\n        self.cos = nn.CosineSimilarity(dim=-1)\n        self.metric = Metric(args)\n\n    def train_loss_fct(self, config, inputs, a, p, n):\n         \n        positive_similarity = self.cos(a.unsqueeze(1), p.unsqueeze(0)) / self.args.temperature\n        negative_similarity = self.cos(a.unsqueeze(1), n.unsqueeze(0)) / self.args.temperature\n        cosine_similarity = torch.cat([positive_similarity, negative_similarity], dim=1).to(self.args.device)\n\n        labels = torch.arange(cosine_similarity.size(0)).long().to(self.args.device)\n\n        loss = config['criterion'](cosine_similarity, labels)\n\n        return loss\n\n    def evaluation_during_training(self, embeddings1, embeddings2, labels, indicator):\n\n        embeddings1 = embeddings1.cpu().numpy()\n        embeddings2 = embeddings2.cpu().numpy()\n        labels = labels['value'].cpu().numpy().flatten()\n\n        cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))\n        manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)\n        euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)\n        dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]\n\n        eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)\n        eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)\n\n        eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)\n        eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)\n\n        eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)\n     
   eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)\n\n        eval_pearson_dot, _ = pearsonr(labels, dot_products)\n        eval_spearman_dot, _ = spearmanr(labels, dot_products)\n\n        score = {'eval_pearson_cosine': eval_pearson_cosine,\n                 'eval_spearman_cosine': eval_spearman_cosine,\n                 'eval_pearson_manhattan': eval_pearson_manhattan,\n                 'eval_spearman_manhattan': eval_spearman_manhattan,\n                 'eval_pearson_euclidean': eval_pearson_euclidean,\n                 'eval_spearman_euclidean': eval_spearman_euclidean,\n                 'eval_pearson_dot': eval_pearson_dot,\n                 'eval_spearman_dot': eval_spearman_dot}\n\n        self.metric.update_indicator(indicator, score)\n\n        return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot)\n"
  },
  {
    "path": "KoSentenceT5/model/setting.py",
    "content": "import torch\nimport random\nimport logging\nimport numpy as np\nfrom argparse import ArgumentParser\n\n\nclass Arguments():\n\n    def __init__(self):\n        self.parser = ArgumentParser()\n\n    def add_type_of_processing(self):\n        self.add_argument('--opt_level', type=str, default='O1')\n        self.add_argument('--fp16', type=str, default='True')\n        self.add_argument('--train', type=str, default='True')\n        self.add_argument('--test', type=str, default='True')\n        self.add_argument('--multi_gpu', type=str, default='False')\n        self.add_argument('--device', type=str, default=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))\n\n    def add_hyper_parameters(self):\n        self.add_argument('--model', type=str, default='KETI-AIR/ke-t5-base')\n        self.add_argument('--patient', type=int, default=10)\n        self.add_argument('--dropout', type=float, default=0.1)\n        self.add_argument('--max_len', type=int, default=50)\n        self.add_argument('--batch_size', type=int, default=256)\n        self.add_argument('--epochs', type=int, default=2)\n        self.add_argument('--eval_steps', type=int, default=250)\n        self.add_argument('--seed', type=int, default=12)\n        self.add_argument('--lr', type=float, default=0.00005)\n        self.add_argument('--weight_decay', type=float, default=0.1)\n        self.add_argument('--warmup_ratio', type=float, default=0.05)\n        self.add_argument('--temperature', type=float, default=0.05)\n\n    def add_data_parameters(self):\n        self.add_argument('--train_data', type=str, default='train_nli.tsv')\n        self.add_argument('--valid_data', type=str, default='valid_sts.tsv')\n        self.add_argument('--test_data', type=str, default='test_sts.tsv')\n        self.add_argument('--task', type=str, default='NLU')\n        self.add_argument('--path_to_data', type=str, default='./data/')\n        self.add_argument('--path_to_save', type=str, default='./output/')\n        self.add_argument('--path_to_saved_model', type=str, default='./output/')\n        self.add_argument('--ckpt', type=str, default='best_checkpoint.pt')\n\n    def print_args(self, args):\n        for idx, (key, value) in enumerate(args.__dict__.items()):\n            if idx == 0:print(\"argparse{\\n\", \"\\t\", key, \":\", value)\n            elif idx == len(args.__dict__) - 1:print(\"\\t\", key, \":\", value, \"\\n}\")\n            else:print(\"\\t\", key, \":\", value)\n\n    def add_argument(self, *args, **kw_args):\n        return self.parser.add_argument(*args, **kw_args)\n\n    def parse(self):\n        args = self.parser.parse_args()\n        self.print_args(args)\n\n        return args\n\n\nclass Setting():\n\n    def set_logger(self):\n\n        _logger = logging.getLogger()\n        formatter = logging.Formatter(\n            '[%(levelname)s] %(asctime)s [ %(message)s ] | file::%(filename)s | line::%(lineno)s')\n\n        stream_handler = logging.StreamHandler()\n        stream_handler.setFormatter(formatter)\n\n        _logger.addHandler(stream_handler)\n        _logger.setLevel(logging.DEBUG)\n\n        return _logger\n\n    def set_seed(self, args):\n\n        seed = args.seed\n\n        random.seed(seed)\n        np.random.seed(seed)\n\n        torch.manual_seed(seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def run(self):\n\n        parser = Arguments()\n        parser.add_type_of_processing()\n        parser.add_hyper_parameters()\n        parser.add_data_parameters()\n\n        args = parser.parse()\n        logger = self.set_logger()\n        self.set_seed(args)\n\n        return args, logger\n"
  },
  {
    "path": "KoSentenceT5/model/simcse/kost5.py",
    "content": "from torch import nn\n\n\nclass KoSentenceT5(nn.Module):\n    def __init__(self, model):\n        super(KoSentenceT5, self).__init__()\n        # Accept an already-loaded Hugging Face model instance,\n        # matching how Processor.model_setting constructs this wrapper\n        self.model = model\n\n    def pooling(self, output):\n        # Use the first decoder token's hidden state as the sentence representation,\n        # so downstream loss/evaluation code receives plain tensors\n        return output.last_hidden_state[:, 0, :]\n\n    def forward(self, config, inputs, mode):\n\n        if mode == 'train':\n            anchor_pooler = self.pooling(self.model(input_ids=inputs['anchor']['source'],\n                                                    attention_mask=inputs['anchor']['attention_mask'],\n                                                    decoder_input_ids=inputs['anchor']['dec_ids']))\n\n            positive_pooler = self.pooling(self.model(input_ids=inputs['positive']['source'],\n                                                      attention_mask=inputs['positive']['attention_mask'],\n                                                      decoder_input_ids=inputs['positive']['dec_ids']))\n\n            negative_pooler = self.pooling(self.model(input_ids=inputs['negative']['source'],\n                                                      attention_mask=inputs['negative']['attention_mask'],\n                                                      decoder_input_ids=inputs['negative']['dec_ids']))\n\n            return anchor_pooler, positive_pooler, negative_pooler\n\n        else:\n            sentence_1_pooler = self.pooling(self.model(input_ids=inputs['sentence_1']['source'],\n                                                        attention_mask=inputs['sentence_1']['attention_mask'],\n                                                        decoder_input_ids=inputs['sentence_1']['dec_ids']))\n\n            sentence_2_pooler = self.pooling(self.model(input_ids=inputs['sentence_2']['source'],\n                                                        attention_mask=inputs['sentence_2']['attention_mask'],\n                                                        decoder_input_ids=inputs['sentence_2']['dec_ids']))\n\n            return sentence_1_pooler, sentence_2_pooler\n\n    def encode(self, inputs, device):\n        # Run only the T5 encoder (no decoder pass is needed for inference)\n        # and mean-pool the token states over the attention mask\n        attention_mask = inputs['attention_mask'].to(device)\n        embeddings = self.model.encoder(input_ids=inputs['source'].to(device),\n                                        attention_mask=attention_mask,\n                                        ).last_hidden_state\n\n        return (embeddings * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)\n"
  },
  {
    "path": "KoSentenceT5/model/simcse/processor.py",
    "content": "import os\nimport logging\nfrom apex import amp\nimport torch.nn as nn\nfrom tqdm import tqdm\nimport torch.quantization\nimport torch.optim as optim\nfrom model.loss import Loss\nfrom model.utils import Metric\nfrom accelerate import Accelerator\nfrom transformers import AutoModel\nfrom model.simcse.kost5 import KoSentenceT5\nfrom data.dataloader import get_loader\nfrom transformers import get_linear_schedule_with_warmup\n\nlogger = logging.getLogger(__name__)\n\n\nclass Processor():\n\n    def __init__(self, args):\n        self.args = args\n        self.config = None\n        self.metric = Metric(args)\n        self.loss = Loss(args)\n        self.total_steps = 0\n        self.model_checker = {'early_stop': False,\n                              'early_stop_patient': 0,\n                              'best_valid_score': 0}\n        self.dev_progress = {'score': 0, 'iter': 0}\n        self.model_progress = {'loss': 0, 'iter': 0}\n\n    def run(self, inputs, indicator=None, type=None):\n\n        if type == 'train':\n            anchor_embeddings, positive_embeddings, negative_embeddings = self.config['model'](self.config, inputs, type)\n            loss = self.loss.train_loss_fct(self.config,\n                                            inputs, \n                                            anchor_embeddings, \n                                            positive_embeddings, \n                                            negative_embeddings)\n            return loss\n        else:\n            sentence_1_embeddings, sentence_2_embeddings = self.config['model'](self.config, inputs, type)\n            \n            score = self.loss.evaluation_during_training(sentence_1_embeddings,\n                                                         sentence_2_embeddings,\n                                                         inputs['label'],\n                                                         indicator)\n            return score\n\n    def 
progress(self, loss):\n        self.model_progress['loss'] += loss\n        self.model_progress['iter'] += 1\n\n    def progress_validation(self, score):\n        self.dev_progress['score'] += score\n        self.dev_progress['iter'] += 1\n\n    def return_value(self):\n        loss = self.model_progress['loss'].data.cpu().numpy() / self.model_progress['iter']\n        acc = self.model_progress['acc'].data.cpu().numpy() / self.model_progress['iter']\n\n        return loss, acc\n\n    def get_object(self, tokenizer, model):\n\n        no_decay = ['bias', 'LayerNorm.weight']\n        optimizer_grouped_parameters = [\n            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n             'weight_decay': self.args.weight_decay},\n            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n             'weight_decay': 0.0}\n        ]\n\n        criterion = nn.CrossEntropyLoss()\n        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)\n\n        return criterion, optimizer\n\n    def get_scheduler(self, optim, train_loader):\n        train_total = len(train_loader) * self.args.epochs\n        scheduler = get_linear_schedule_with_warmup(optim,\n                                                    num_warmup_steps=self.args.warmup_ratio * train_total,\n                                                    num_training_steps=train_total)\n\n        return scheduler, train_total\n\n    def model_setting(self):\n        accelerator = Accelerator(fp16=False)\n\n        loader, tokenizer = get_loader(self.args, self.metric)\n\n        model = KoSentenceT5(AutoModel.from_pretrained(self.args.model))\n        vocab = tokenizer.get_vocab()\n        model.model.resize_token_embeddings(len(vocab))\n\n        if self.args.multi_gpu == 'True':\n            model = nn.DataParallel(model, output_device=0)\n        model.to(self.args.device)\n        \n        criterion, optimizer 
= self.get_object(tokenizer, model)\n\n        if self.args.train == 'True':\n            scheduler, total_steps = self.get_scheduler(optimizer, loader['train'])\n            self.total_steps = total_steps\n        else:\n            scheduler = None\n\n        config = {'loader': loader,\n                  'optimizer': optimizer,\n                  'criterion': criterion,\n                  'scheduler': scheduler,\n                  'tokenizer': tokenizer,\n                  'accelerator': accelerator,\n                  'args': self.args,\n                  'model': model}\n\n        config['model'], config['optimizer'] = accelerator.prepare(model, optimizer)\n\n        self.config = config\n\n        return self.config\n\n    def train(self, epoch):\n        self.config['model'].train()\n        \n        train_loader = self.config['accelerator'].prepare(self.config['loader']['train'])\n        for step, batch in enumerate(tqdm(train_loader)):\n            self.config['optimizer'].zero_grad()\n\n            inputs = batch\n\n            loss = self.run(inputs, type='train')\n            loss = torch.mean(loss)\n\n            if self.args.fp16 == 'True' and self.args.multi_gpu == 'False':\n                with amp.scale_loss(loss, self.config['optimizer']) as scaled_loss:\n                    scaled_loss.backward()\n            else:\n                self.config['accelerator'].backward(loss)\n\n            self.config['optimizer'].step()\n            self.config['scheduler'].step()\n\n            self.progress(loss.data)\n\n            if self.model_progress['iter'] % self.args.eval_steps == 0 or self.model_progress['iter'] == self.total_steps:\n                valid_score = self.valid()\n\n                performance = {'tl': loss, 'vs': valid_score, 'ep': epoch, 'step': self.model_progress['iter']}\n\n                self.metric.save_model(self.config, performance, self.model_checker)\n\n    def valid(self):\n        self.config['model'].eval()\n        
self.dev_progress = self.dev_progress.fromkeys(self.dev_progress, 0)\n\n        score_indicator = {'eval_pearson_cosine': 0,\n                           'eval_spearman_cosine': 0,\n                           'eval_pearson_manhattan': 0,\n                           'eval_spearman_manhattan': 0,\n                           'eval_pearson_euclidean': 0,\n                           'eval_spearman_euclidean': 0,\n                           'eval_pearson_dot': 0,\n                           'eval_spearman_dot': 0}\n        \n        valid_loader = self.config['accelerator'].prepare(self.config['loader']['valid'])\n        with torch.no_grad():\n            for step, batch in enumerate(valid_loader):\n                inputs = batch\n                score = self.run(inputs, indicator=score_indicator, type='valid')\n                self.progress_validation(score)\n\n        score = self.metric.cal_dev_score(self.dev_progress, score_indicator)\n\n        return score\n\n    def test(self):\n        self.config['model'].load_state_dict(torch.load(self.args.path_to_saved_model))\n        self.config['model'].eval()\n        \n        self.dev_progress = self.dev_progress.fromkeys(self.dev_progress, 0)\n        \n        score_indicator = {'eval_pearson_cosine': 0,\n                           'eval_spearman_cosine': 0,\n                           'eval_pearson_manhattan': 0,\n                           'eval_spearman_manhattan': 0,\n                           'eval_pearson_euclidean': 0,\n                           'eval_spearman_euclidean': 0,\n                           'eval_pearson_dot': 0,\n                           'eval_spearman_dot': 0}\n\n        with torch.no_grad():\n            for step, batch in enumerate(self.config['loader']['test']):\n                inputs = batch\n                score = self.run(inputs, indicator=score_indicator, type='test')\n\n                self.progress_validation(score)\n\n        logger.info('### TEST SCORE ###')\n        score = 
self.metric.cal_dev_score(self.dev_progress, score_indicator)\n"
  },
  {
    "path": "KoSentenceT5/model/utils.py",
    "content": "import os\nimport torch\nimport logging\nfrom tensorboardX import SummaryWriter\n\nlogger = logging.getLogger(__name__)\nwriter = SummaryWriter()\n\n\nclass Metric():\n\n    def __init__(self, args):\n        self.args = args\n\n    def get_lr(self, optimizer):\n        return optimizer.state_dict()['param_groups'][0]['lr']\n\n    def count_parameters(self, model):\n        print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n\n    def cal_acc(self, yhat, y):\n        with torch.no_grad():\n            yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value\n            acc = (yhat == y).float().mean()\n\n        return acc\n\n    def cal_time(self, start_time, end_time):\n        elapsed_time = end_time - start_time\n        elapsed_mins = int(elapsed_time / 60)\n        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n\n        return elapsed_mins, elapsed_secs\n\n    def cal_dev_score(self, score, indicator):\n        validation_score = score['score'] / score['iter']\n        for key, value in indicator.items():\n            indicator[key] /= score['iter']\n\n        print(\"\\n\\nCosine-Similarity :\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_cosine'], indicator['eval_spearman_cosine']))\n        print(\"Manhattan-Distance:\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_manhattan'], indicator['eval_spearman_manhattan']))\n        print(\"Euclidean-Distance:\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_euclidean'], indicator['eval_spearman_euclidean']))\n        print(\"Dot-Product-Similarity:\\tPearson: {:.4f}\\tSpearman: {:.4f}\\n\".format(\n            indicator['eval_pearson_dot'], indicator['eval_spearman_dot']))\n\n        return validation_score\n\n    def update_indicator(self, indicator, score):\n        for key, value in indicator.items():\n            if key == 
'eval_spearman_cosine':\n                indicator[key] += score['eval_spearman_cosine']\n            elif key == 'eval_pearson_cosine':\n                indicator[key] += score['eval_pearson_cosine']\n            elif key == 'eval_spearman_manhattan':\n                indicator[key] += score['eval_spearman_manhattan']\n            elif key == 'eval_pearson_manhattan':\n                indicator[key] += score['eval_pearson_manhattan']\n            elif key == 'eval_spearman_euclidean':\n                indicator[key] += score['eval_spearman_euclidean']\n            elif key == 'eval_pearson_euclidean':\n                indicator[key] += score['eval_pearson_euclidean']\n            elif key == 'eval_spearman_dot':\n                indicator[key] += score['eval_spearman_dot']\n            elif key == 'eval_pearson_dot':\n                indicator[key] += score['eval_pearson_dot']\n\n    def draw_graph(self, cp):\n        writer.add_scalars('loss_graph', {'train': cp['tl'], 'valid': cp['vl']}, cp['ep'])\n        writer.add_scalars('acc_graph', {'train': cp['tma'], 'valid': cp['vma']}, cp['ep'])\n\n    def performance_check(self, cp, config):\n        print(f'\\t==Epoch: {cp[\"ep\"] + 1:02} | Epoch Time: {cp[\"epm\"]}m {cp[\"eps\"]}s==')\n        print(f'\\t==Train Loss: {cp[\"tl\"]:.4f} | Train acc: {cp[\"tma\"]:.4f}==')\n        print(f'\\t==Valid Loss: {cp[\"vl\"]:.4f} | Valid acc: {cp[\"vma\"]:.4f}==')\n        print(f'\\t==Epoch latest LR: {self.get_lr(config[\"optimizer\"]):.9f}==\\n')\n\n    def print_size_of_model(self, model):\n        torch.save(model.state_dict(), \"temp.p\")\n        print('Size (MB):', os.path.getsize(\"temp.p\") / 1e6)\n        os.remove('temp.p')\n\n    def move2device(self, sample, device):\n        if len(sample) == 0:\n            return {}\n\n        def _move_to_device(maybe_tensor, device):\n            if torch.is_tensor(maybe_tensor):\n                return maybe_tensor.to(device)\n            elif isinstance(maybe_tensor, 
dict):\n                return {\n                    key: _move_to_device(value, device)\n                    for key, value in maybe_tensor.items()\n                    }\n            elif isinstance(maybe_tensor, list):\n                return [_move_to_device(x, device) for x in maybe_tensor]\n            elif isinstance(maybe_tensor, tuple):\n                return [_move_to_device(x, device) for x in maybe_tensor]\n            else:\n                return maybe_tensor\n\n        return _move_to_device(sample, device)\n\n    def save_model(self, config, cp, pco):\n        if not os.path.exists(config['args'].path_to_save):\n            os.makedirs(config['args'].path_to_save)\n\n        sorted_path = config['args'].path_to_save + \"kosimcse-\" + config['args'].model.replace(\"/\", \"-\") + '.pt'\n        if cp['vs'] > pco['best_valid_score']:\n            # pco['early_stop_patient'] = 0\n            pco['best_valid_score'] = cp['vs']\n            \n            unwrapped_model = config['accelerator'].unwrap_model(config['model'])\n            config['accelerator'].save(unwrapped_model.state_dict(), sorted_path)\n\n            #state = {'model': config['model'].state_dict(),\n            #         'optimizer': config['optimizer'].state_dict()}\n\n            #torch.save(state, sorted_path)\n            print(f'\\t## SAVE {sorted_path} |'\n                  f' valid_score: {cp[\"vs\"]:.4f} |'\n                  f' epochs: {cp[\"ep\"]} |'\n                  f' steps: {cp[\"step\"]} ##\\n')\n\n        # self.draw_graph(cp)\n        # self.performance_check(cp, config)\n\n\ndef pytorch_cos_sim(a, b):\n    \"\"\"\n    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.\n    This function can be used as a faster replacement for 1-scipy.spatial.distance.cdist(a,b)\n    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])\n    \"\"\"\n    if not isinstance(a, torch.Tensor):\n        a = torch.tensor(a)\n\n    if not isinstance(b, torch.Tensor):\n       
 b = torch.tensor(b)\n\n    if len(a.shape) == 1:\n        a = a.unsqueeze(0)\n\n    if len(b.shape) == 1:\n        b = b.unsqueeze(0)\n\n    a_norm = a / a.norm(dim=1)[:, None]\n    b_norm = b / b.norm(dim=1)[:, None]\n    return torch.mm(a_norm, b_norm.transpose(0, 1))\n"
  },
  {
    "path": "KoSentenceT5/run_example.sh",
    "content": "#!/bin/bash\n\nCUDA_VISIBLE_DEVICES=0,1 python main.py \\\n  --model etri-t5 \\\n  --multi_gpu True \\\n  --test False \\\n  --max_len 110 \\\n  --batch_size 64 \\\n  --epochs 2 \\\n  --eval_steps 125 \\\n  --lr 0.0001 \\\n  --warmup_ratio 0.01 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --train_data train_nli.tsv \\\n  --valid_data valid_sts.tsv\n\nCUDA_VISIBLE_DEVICES=1 python main.py \\\n  --model etri-t5 \\\n  --train False \\\n  --test True \\\n  --max_len 110 \\\n  --batch_size 64 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --test_data test_sts.tsv \\\n  --path_to_saved_model output/\n"
  },
  {
    "path": "KoSimCSE/README.md",
    "content": "# KoSimCSE\n[[Github]](https://github.com/princeton-nlp/SimCSE) Official implementation of SimCSE. <br>\nKoSimCSE : Korean Sentence Embeddings using contrastive learning.\n\n## Quick start\n- If you want to do inference quickly, download the pre-trained models and then you can start some downstream tasks.\n```\nbash get_model_checkpoint.sh\npython SemanticSearch.py\n```\n\n## Training \n- Before training or evaluation, please download the datasets by running\n```\nbash get_model_dataset.sh\n```\n### Train KoSimCSE (Supervised Only)\n  ```\n  python main.py \\\n    --model klue/bert-base \\\n    --test False \\\n    --max_len 50 \\\n    --batch_size 256 \\\n    --epochs 2 \\\n    --eval_steps 125 \\\n    --lr 0.0001 \\\n    --warmup_ratio 0.1 \\\n    --temperature 0.05 \\\n    --path_to_data ../Dataset/ \\\n    --train_data train_nli.tsv \\\n    --valid_data valid_sts.tsv\n  ```\n### Evaluation\n  ```\n  python main.py \\\n    --model klue/bert-base \\\n    --train False \\\n    --test True \\\n    --max_len 50 \\\n    --batch_size 256 \\\n    --temperature 0.05 \\\n    --path_to_data ../Dataset/ \\\n    --test_data test_sts.tsv \\\n    --path_to_saved_model output/kosimcse-klue-bert-base.pt\n  ```\n\n### Run Examples\n```\nbash run_example.sh\n```\n### Hyperparameters\n- Train KoSimCSE (BERT BASE)\n  1. Pooling Method: [CLS] strategy\n  2. Batch Size: 256\n  3. Evaluation Steps: 125\n  4. Epochs: 2\n  5. Token Max Length: 128\n  6. Learning Rate: 0.0001\n  7. Warmup Ratio: 0.1\n  8. Temperature: 0.05\n  \n- Train KoSimCSE (RoBERTa BASE)\n  1. Pooling Method: [CLS] strategy\n  2. Batch Size: 256\n  3. Evaluation Steps: 125\n  4. Epochs: 2\n  5. Token Max Length: 128\n  6. Learning Rate: 0.0001\n  7. Warmup Ratio: 0.05\n  8. 
Temperature: 0.05\n\n### Semantic Search\n```\npython SemanticSearch.py\n```\n```python\nimport torch\nimport numpy as np\nfrom model.simcse.bert import BERT\nfrom model.utils import pytorch_cos_sim\nfrom data.dataloader import convert_to_tensor\nfrom transformers import AutoModel, AutoTokenizer\n\ndef main():\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n    model = BERT(AutoModel.from_pretrained('BM-K/KoSimCSE-roberta'))\n    tokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta')\n\n    model.to(device)\n    model.eval()\n\n    # Corpus with example sentences\n    corpus = ['한 남자가 음식을 먹는다.',\n              '한 남자가 빵 한 조각을 먹는다.',\n              '그 여자가 아이를 돌본다.',\n              '한 남자가 말을 탄다.',\n              '한 여자가 바이올린을 연주한다.',\n              '두 남자가 수레를 숲 속으로 밀었다.',\n              '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n              '원숭이 한 마리가 드럼을 연주한다.',\n              '치타 한 마리가 먹이 뒤에서 달리고 있다.']\n\n    inputs_corpus = convert_to_tensor(corpus, tokenizer, device)\n\n    corpus_embeddings = model.encode(inputs_corpus, device)\n\n    # Query sentences:\n    queries = ['한 남자가 파스타를 먹는다.',\n               '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n               '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n    # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity\n    top_k = 5\n    for query in queries:\n        query_embedding = model.encode(convert_to_tensor([query], tokenizer, device), device)\n\n        cos_scores = pytorch_cos_sim(query_embedding, corpus_embeddings)[0]\n        cos_scores = cos_scores.cpu().detach().numpy()\n\n        top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]\n\n        print(\"\\n\\n======================\\n\\n\")\n        print(\"Query:\", query)\n        print(\"\\nTop 5 most similar sentences in corpus:\")\n\n        for idx in top_results[0:top_k]:\n            print(corpus[idx].strip(), \"(Score: %.4f)\" % (cos_scores[idx]))\n```\n\n- Results are as follows :\n\n```\n\nQuery: 한 남자가 파스타를 먹는다.\n\nTop 5 most similar sentences in corpus:\n한 남자가 음식을 먹는다. (Score: 0.6141)\n한 남자가 빵 한 조각을 먹는다. 
 (Score: 0.5952)\n한 남자가 말을 탄다. (Score: 0.1231)\n한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.0752)\n두 남자가 수레를 숲 속으로 밀었다. (Score: 0.0486)\n\n\n======================\n\n\nQuery: 고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.\n\nTop 5 most similar sentences in corpus:\n원숭이 한 마리가 드럼을 연주한다. (Score: 0.6656)\n치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.2988)\n한 여자가 바이올린을 연주한다. (Score: 0.1566)\n한 남자가 말을 탄다. (Score: 0.1112)\n한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.0262)\n\n\n======================\n\n\nQuery: 치타가 들판을 가로 질러 먹이를 쫓는다.\n\nTop 5 most similar sentences in corpus:\n치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.7570)\n두 남자가 수레를 숲 속으로 밀었다. (Score: 0.3658)\n원숭이 한 마리가 드럼을 연주한다. (Score: 0.3583)\n한 남자가 말을 탄다. (Score: 0.0505)\n그 여자가 아이를 돌본다. (Score: -0.0087)\n```\n\n### Clustering\n```python\nimport torch\n\nfrom tqdm import tqdm\nfrom sklearn.cluster import KMeans\nfrom transformers import (\n    AutoModel,\n    AutoTokenizer\n)\n\ndef encode(model=None,\n           tokenizer=None,\n           corpus=None,\n           ):\n\n    tokenized_corpus = tokenizer(corpus,\n                                 truncation=True,\n                                 return_tensors='pt',\n                                 max_length=token_max_len,\n                                 padding='max_length')\n\n    embeddings, _ = model(input_ids=tokenized_corpus['input_ids'].to(device),\n                          token_type_ids=tokenized_corpus['token_type_ids'].to(device),\n                          attention_mask=tokenized_corpus['attention_mask'].to(device),\n                          return_dict=False)\n\n    return embeddings[:, 0].cpu().detach()\n\ndef get_model():\n\n    model = AutoModel.from_pretrained('BM-K/KoSimCSE-roberta-multitask')\n    tokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta-multitask')\n\n    model.eval()\n\n    return model.to(device), tokenizer\n\ndef get_cluster(corpus_embeddings\n                ):\n\n    clustering_model = KMeans(n_clusters=num_clusters)\n    
clustering_model.fit(corpus_embeddings)\n\n    return clustering_model.labels_\n\ndef main():\n    # Corpus with example sentences\n    corpus = ['한 남자가 음식을 먹는다.',\n              '한 남자가 빵 한 조각을 먹는다.',\n              '그 여자가 아이를 돌본다.',\n              '한 남자가 말을 탄다.',\n              '한 여자가 바이올린을 연주한다.',\n              '두 남자가 수레를 숲 속으로 밀었다.',\n              '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n              '원숭이 한 마리가 드럼을 연주한다.',\n              '치타 한 마리가 먹이 뒤에서 달리고 있다.',\n              '한 남자가 파스타를 먹는다.',\n              '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n              '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n    n_corpus = len(corpus)\n\n    model, tokenizer = get_model()\n\n    corpus_embeddings = torch.tensor([])\n    for start_idx in tqdm(range(0, n_corpus, embedding_batch)):\n        batch_corpus = corpus[start_idx:start_idx+embedding_batch]\n        batch_embedding = encode(model, tokenizer, batch_corpus)\n        corpus_embeddings = torch.cat([corpus_embeddings, batch_embedding], dim=0)\n\n    assert n_corpus == corpus_embeddings.size(0)\n\n    cluster_assignment = get_cluster(corpus_embeddings)\n\n    clustered_sentences = [[] for _ in range(num_clusters)]\n    for sentence_id, cluster_id in enumerate(cluster_assignment):\n        clustered_sentences[cluster_id].append(corpus[sentence_id])\n\n    for i, cluster in enumerate(clustered_sentences):\n        print(\"Cluster \", i + 1)\n        print(cluster)\n        print(\"\")\n\nif __name__ == '__main__':\n    num_clusters = 5\n    token_max_len = 50\n    embedding_batch = 3\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n\n    main()\n```\n\n- Results are as follows :\n\n```\n\nCluster  1\n['한 남자가 음식을 먹는다.', '한 남자가 빵 한 조각을 먹는다.', '한 남자가 파스타를 먹는다.']\n\nCluster  2\n['한 여자가 바이올린을 연주한다.', '원숭이 한 마리가 드럼을 연주한다.', '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.']\n\nCluster  3\n['한 남자가 말을 탄다.', '두 남자가 수레를 숲 속으로 밀었다.', '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.']\n\nCluster  4\n['그 여자가 아이를 돌본다.']\n\nCluster  5\n['치타 한 마리가 먹이 뒤에서 달리고 
있다.', '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n```\n"
  },
  {
    "path": "KoSimCSE/SemanticSearch.py",
    "content": "import numpy as np\nfrom model.utils import pytorch_cos_sim\nfrom data.dataloader import convert_to_tensor, example_model_setting\n\n\ndef main():\n    model_name = 'klue/bert-base'\n    model_ckpt = '../Checkpoint/KoSimCSE/kosimcse-klue-bert-base.pt'\n    model, tokenizer, device = example_model_setting(model_ckpt, model_name)\n\n    # Corpus with example sentences\n    corpus = ['한 남자가 음식을 먹는다.',\n              '한 남자가 빵 한 조각을 먹는다.',\n              '그 여자가 아이를 돌본다.',\n              '한 남자가 말을 탄다.',\n              '한 여자가 바이올린을 연주한다.',\n              '두 남자가 수레를 숲 속으로 밀었다.',\n              '한 남자가 담으로 싸인 땅에서 백마를 타고 있다.',\n              '원숭이 한 마리가 드럼을 연주한다.',\n              '치타 한 마리가 먹이 뒤에서 달리고 있다.']\n\n    inputs_corpus = convert_to_tensor(corpus, tokenizer, device)\n\n    corpus_embeddings = model.encode(inputs_corpus, device)\n\n    # Query sentences:\n    queries = ['한 남자가 파스타를 먹는다.',\n               '고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.',\n               '치타가 들판을 가로 질러 먹이를 쫓는다.']\n\n    # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity\n    top_k = 5\n    for query in queries:\n        query_embedding = model.encode(convert_to_tensor([query], tokenizer, device), device)\n        \n        cos_scores = pytorch_cos_sim(query_embedding, corpus_embeddings)[0]\n        cos_scores = cos_scores.cpu().detach().numpy()\n\n        top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]\n\n        print(\"\\n\\n======================\\n\\n\")\n        print(\"Query:\", query)\n        print(\"\\nTop 5 most similar sentences in corpus:\")\n\n        for idx in top_results[0:top_k]:\n            print(corpus[idx].strip(), \"(Score: %.4f)\" % (cos_scores[idx]))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "KoSimCSE/apex/RNN/README.md",
    "content": "Under construction...\n"
  },
  {
    "path": "KoSimCSE/apex/RNN/RNNBackend.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\nimport torch.nn.functional as F\n\nimport math\n\n\ndef is_iterable(maybe_iterable):\n    return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)\n\n\ndef flatten_list(tens_list):\n    \"\"\"\n    flatten_list\n    \"\"\"\n    if not is_iterable(tens_list):\n        return tens_list\n    \n    return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )\n\n    \n#These modules always assumes batch_first\nclass bidirectionalRNN(nn.Module):\n    \"\"\"\n    bidirectionalRNN\n    \"\"\"\n    def __init__(self, inputRNN, num_layers=1, dropout = 0):\n        super(bidirectionalRNN, self).__init__()\n        self.dropout = dropout\n        self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)\n        self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)\n        self.rnns = nn.ModuleList([self.fwd, self.bckwrd])\n        \n    #collect hidden option will return all hidden/cell states from entire RNN\n    def forward(self, input, collect_hidden=False):\n        \"\"\"\n        forward()\n        \"\"\"\n        seq_len = input.size(0)\n        bsz = input.size(1)\n\n        fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))\n        bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))\n        \n        output = torch.cat( [fwd_out, bckwrd_out], -1 )\n        hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )\n\n        return output, hiddens\n\n    def reset_parameters(self):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_parameters()\n        \n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_hidden(bsz)\n\n    def 
detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.detach_hidden()\n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_hidden(bsz)\n\n    def init_inference(self, bsz):    \n        \"\"\"\n        init_inference()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_inference(bsz)\n\n   \n#assumes hidden_state[0] of inputRNN is output hidden state\n#constructor either takes an RNNCell or list of RNN layers\nclass stackedRNN(nn.Module):        \n    \"\"\"\n    stackedRNN\n    \"\"\"\n    def __init__(self, inputRNN, num_layers=1, dropout=0):\n        super(stackedRNN, self).__init__()\n        \n        self.dropout = dropout\n        \n        if isinstance(inputRNN, RNNCell):\n            self.rnns = [inputRNN]\n            for i in range(num_layers-1):\n                self.rnns.append(inputRNN.new_like(inputRNN.output_size))\n        elif isinstance(inputRNN, list):\n            assert len(inputRNN) == num_layers, \"RNN list length must be equal to num_layers\"\n            self.rnns=inputRNN\n        else:\n            raise RuntimeError()\n        \n        self.nLayers = len(self.rnns)\n        \n        self.rnns = nn.ModuleList(self.rnns)\n\n\n    '''\n    Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])\n    If collect hidden will also return Tuple(\n        [n_hidden_states][sequence steps] Tensor([layer][batch size][features])\n    )\n    If not collect hidden will also return Tuple(\n        [n_hidden_states] Tensor([layer][batch size][features])\n    '''\n    def forward(self, input, collect_hidden=False, reverse=False):\n        \"\"\"\n        forward()\n        \"\"\"\n        seq_len = input.size(0)\n        bsz = input.size(1)\n        inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)\n\n        hidden_states 
= [[] for i in range(self.nLayers)]\n        outputs = []\n\n        for seq in inp_iter:\n            for layer in range(self.nLayers):\n\n                if layer == 0:\n                    prev_out = input[seq]\n                    \n                outs = self.rnns[layer](prev_out)\n\n                if collect_hidden:\n                    hidden_states[layer].append(outs)\n                elif seq == seq_len-1:\n                    hidden_states[layer].append(outs)\n                    \n                prev_out = outs[0]\n\n            outputs.append(prev_out)\n\n        if reverse:\n            outputs = list(reversed(outputs))\n        '''\n        At this point outputs is in format:\n        list( [seq_length] x Tensor([bsz][features]) )\n        need to convert it to:\n        list( Tensor([seq_length][bsz][features]) )\n        '''\n        output = flatten_list(outputs)\n\n        '''\n        hidden_states at this point is in format:\n        list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )\n        need to convert it to:\n          For not collect hidden:\n            list( [hidden_states] x Tensor([layer][bsz][features]) )\n          For collect hidden:\n            list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )\n        '''\n        if not collect_hidden:\n            seq_len = 1\n        n_hid = self.rnns[0].n_hidden_states\n        new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]\n\n\n        for i in range(n_hid):\n            for j in range(seq_len):\n                for k in range(self.nLayers):\n                    new_hidden[i][j][k] = hidden_states[k][j][i]\n\n        hidden_states = new_hidden\n        #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )\n        #Reverse seq_length if reverse\n        if reverse:\n            hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)\n\n        
#flatten layer dimension into tensor\n        hidden_states = list( list(\n            flatten_list(seq) for seq in hidden )\n                        for hidden in hidden_states )\n        \n        #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )\n        #Remove seq_length dimension if not collect_hidden\n        if not collect_hidden:\n            hidden_states = list( entry[0] for entry in hidden_states)\n        return output, hidden_states\n    \n    def reset_parameters(self):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_parameters()\n        \n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_hidden(bsz)\n\n    def detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.detach_hidden()\n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.reset_hidden(bsz)\n\n    def init_inference(self, bsz):    \n        \"\"\" \n        init_inference()\n        \"\"\"\n        for rnn in self.rnns:\n            rnn.init_inference(bsz)\n\nclass RNNCell(nn.Module):\n    \"\"\" \n    RNNCell \n    gate_multiplier is related to the architecture you're working with\n    For LSTM-like it will be 4 and GRU-like will be 3.\n    Always assumes input is NOT batch_first.\n    Output size that's not hidden size will use output projection\n    Hidden_states is number of hidden states that are needed for cell\n    if one will go directly to cell as tensor, if more will go as list\n    \"\"\"\n    def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):\n        super(RNNCell, self).__init__()\n\n        self.gate_multiplier = gate_multiplier\n        self.input_size = 
input_size\n        self.hidden_size = hidden_size\n        self.cell = cell\n        self.bias = bias\n        self.output_size = output_size\n        if output_size is None:\n            self.output_size = hidden_size\n\n        self.gate_size = gate_multiplier * self.hidden_size\n        self.n_hidden_states = n_hidden_states\n\n        self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))\n        self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))\n\n        #Check if there's recurrent projection\n        if(self.output_size != self.hidden_size):\n            self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))\n\n        self.b_ih = self.b_hh = None\n        if self.bias:\n            self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))\n            self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))\n            \n        #hidden states for forward\n        self.hidden = [ None for states in range(self.n_hidden_states)]\n\n        self.reset_parameters()\n\n    def new_like(self, new_input_size=None):\n        \"\"\"\n        new_like()\n        \"\"\"\n        if new_input_size is None:\n            new_input_size = self.input_size\n            \n        return type(self)(self.gate_multiplier,\n                       new_input_size,\n                       self.hidden_size,\n                       self.cell,\n                       self.n_hidden_states,\n                       self.bias,\n                       self.output_size)\n\n    \n    #Use xavier where we can (weights), otherwise use uniform (bias)\n    def reset_parameters(self, gain=1):\n        \"\"\"\n        reset_parameters()\n        \"\"\"\n        stdev = 1.0 / math.sqrt(self.hidden_size)\n        for param in self.parameters():\n            param.data.uniform_(-stdev, stdev)\n    '''\n    Xavier reset:\n    def reset_parameters(self, gain=1):\n        stdv = 1.0 / math.sqrt(self.gate_size)\n\n        for param in 
self.parameters():\n            if (param.dim() > 1):\n                torch.nn.init.xavier_normal(param, gain)\n            else:\n                param.data.uniform_(-stdv, stdv)\n    '''\n    def init_hidden(self, bsz):\n        \"\"\"\n        init_hidden()\n        \"\"\"\n        for param in self.parameters():\n            if param is not None:\n                a_param = param\n                break\n\n        for i, _ in enumerate(self.hidden):\n            if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):\n\n                if i==0:\n                    hidden_size = self.output_size\n                else:\n                    hidden_size = self.hidden_size\n\n                tens = a_param.data.new(bsz, hidden_size).zero_()\n                self.hidden[i] = Variable(tens, requires_grad=False)\n            \n        \n    def reset_hidden(self, bsz):\n        \"\"\"\n        reset_hidden()\n        \"\"\"\n        for i, _ in enumerate(self.hidden):\n            self.hidden[i] = None\n        self.init_hidden(bsz)\n\n    def detach_hidden(self):\n        \"\"\"\n        detach_hidden()\n        \"\"\"\n        for i, _ in enumerate(self.hidden):\n            if self.hidden[i] is None:\n                raise RuntimeError(\"Must initialize hidden state before you can detach it\")\n        for i, _ in enumerate(self.hidden):\n            self.hidden[i] = self.hidden[i].detach()\n        \n    def forward(self, input):\n        \"\"\"\n        forward()\n        if not inited or bsz has changed this will create hidden states\n        \"\"\"\n        self.init_hidden(input.size()[0])\n\n        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden\n        self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)\n        if(self.n_hidden_states > 1):\n            self.hidden = list(self.hidden)\n        else:\n            self.hidden=[self.hidden]\n\n        if self.output_size != 
self.hidden_size:\n            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)\n\n        return tuple(self.hidden)\n"
  },
  {
    "path": "KoSimCSE/apex/RNN/__init__.py",
    "content": "from .models import LSTM, GRU, ReLU, Tanh, mLSTM\n\n__all__ = ['models']\n"
  },
  {
    "path": "KoSimCSE/apex/RNN/cells.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .RNNBackend import RNNCell\n\nfrom torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend\n\nimport math \n\n\nclass mLSTMRNNCell(RNNCell):\n    \"\"\"\n    mLSTMRNNCell\n    \"\"\"\n\n    def __init__(self, input_size, hidden_size, bias = False, output_size = None):\n        gate_multiplier = 4\n        super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)\n\n        self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))\n        self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))\n\n        self.reset_parameters()\n\n    def forward(self, input):\n        \"\"\"\n        mLSTMRNNCell.forward()\n        \"\"\"\n        #if not inited or bsz has changed this will create hidden states\n        self.init_hidden(input.size()[0])\n\n        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden\n\n        self.hidden = list(\n                           self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,\n                           b_ih=self.b_ih, b_hh=self.b_hh)\n        )\n        \n        if self.output_size != self.hidden_size:\n            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)\n        return tuple(self.hidden)\n\n\n    def new_like(self, new_input_size=None):\n        if new_input_size is None:\n            new_input_size = self.input_size\n        \n        return type(self)(\n            new_input_size,\n            self.hidden_size,\n            self.bias,\n            self.output_size)\n\ndef mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):\n    \"\"\"\n    mLSTMCell\n    \"\"\"\n\n    if input.is_cuda:\n        igates = F.linear(input, w_ih)\n        m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)\n        hgates = F.linear(m, 
w_hh)\n\n        state = fusedBackend.LSTMFused.apply\n        return state(igates, hgates, hidden[1], b_ih, b_hh)\n\n    hx, cx = hidden\n    \n    m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)\n    gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)\n\n    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)\n\n    ingate = F.sigmoid(ingate)\n    forgetgate = F.sigmoid(forgetgate)\n    cellgate = F.tanh(cellgate)\n    outgate = F.sigmoid(outgate)\n    \n    cy = (forgetgate * cx) + (ingate * cellgate)\n    hy = outgate * F.tanh(cy)\n    \n    return hy, cy\n                                                                            \n"
  },
  {
    "path": "KoSimCSE/apex/RNN/models.py",
    "content": "import torch\n\nfrom torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell\n\nfrom .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell\nfrom .cells import mLSTMRNNCell, mLSTMCell\n\ndef toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):\n    \"\"\"\n    :class:`toRNNBackend`\n    \"\"\"\n\n    if bidirectional:\n        return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)\n    else:\n        return stackedRNN(inputRNN, num_layers, dropout = dropout)\n\n\ndef LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`LSTM`\n    \"\"\"\n    inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`GRU`\n    \"\"\"\n    inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`ReLU`\n    \"\"\"\n    inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\ndef Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`Tanh`\n    \"\"\"\n    inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n        \ndef mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, 
bidirectional=False, output_size = None):\n    \"\"\"\n    :class:`mLSTM`\n    \"\"\"\n    inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)\n    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/__init__.py",
    "content": "# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten\nimport torch\nimport warnings\n\nif torch.distributed.is_available():\n    from . import parallel\n\nfrom . import amp\nfrom . import fp16_utils\n\n# For optimizers and normalization there is no Python fallback.\n# Absence of cuda backend is a hard error.\n# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda\n# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext\n# so they expect those backends to be available, but for some reason they actually aren't\n# available (for example because they built improperly in a way that isn't revealed until\n# load time) the error message is timely and visible.\nfrom . import optimizers\nfrom . import normalization\nfrom . import pyprof\n"
  },
  {
    "path": "KoSimCSE/apex/amp/README.md",
    "content": "# amp: Automatic Mixed Precision\n\n## Annotating User Functions\n\nNearly all PyTorch user code needs nothing more than the two steps\nabove to use amp. After all, custom layers are built out of simpler\nPyTorch components, and amp already can see those.\n\nHowever, any custom C++ or CUDA code is outside of amp's (default)\nview of things. For example, suppose I implemented a new recurrent\ncell called a \"forgetful recurrent unit\" that calls directly into a\nCUDA backend:\n\n```python\nfrom backend import FRUBackend\n\ndef fru(input, hidden, weight, bias):\n    # call to CUDA code\n    FRUBackend(input, hidden, weight, bias)\n```\n\nIn this case, it is possible to get a runtime type mismatch. For\nexample, you might have `input` in fp16, and `weight` in fp32, and amp\ndoesn't have the visibility to insert an appropriate cast.\n\namp exposes two ways to handle \"invisible\" backend code: function\nannotations and explicit registration.\n\n#### Function annotation\n\nThe first way to handle backend code is a set of function annotations:\n\n- `@amp.half_function`\n- `@amp.float_function`\n- `@amp.promote_function`\n\nThese correspond to:\n\n- Cast all arguments to fp16\n- Cast all argumnets fo fp32\n- If there are any type mismatches, cast everything to the widest type\n\nIn our example, we believe that the FRU unit is fp16-safe and will get\nperformance gains from casting its arguments to fp16, so we write:\n\n```python\n@amp.half_function\ndef fru(input, hidden, weight, bias):\n    #...\n```\n\n#### Explicit registration\n\nThe other way to handle backend code is with explicit function\nregistration:\n\n- `amp.register_half_function(module, function_name)`\n- `amp.register_float_function(module, function_name)`\n- `amp.register_promote_function(module, function_name)`\n\nWhen using this API, `module` is the containing class or module for\nthe function, and `function_name` is the _string_ name of the\nfunction. 
Note that the function must be registered before the call to\n`amp.initialize()`.\n\nFor our FRU unit, we can register the backend function directly:\n\n```python\nimport backend\n\namp.register_half_function(backend, 'FRUBackend')\n```\n"
  },
  {
    "path": "KoSimCSE/apex/amp/__init__.py",
    "content": "from .amp import init, half_function, float_function, promote_function,\\\n    register_half_function, register_float_function, register_promote_function\nfrom .handle import scale_loss, disable_casts\nfrom .frontend import initialize, state_dict, load_state_dict\nfrom ._amp_state import master_params, _amp_state\n"
  },
  {
    "path": "KoSimCSE/apex/amp/__version__.py",
    "content": "VERSION = (0, 1, 0)\n__version__ = '.'.join(map(str, VERSION))\n"
  },
  {
    "path": "KoSimCSE/apex/amp/_amp_state.py",
    "content": "# This is a \"header object\" that allows different amp modules to communicate.\n# I'm a C++ guy, not a python guy.  I decided this approach because it seemed most C++-like.\n# But apparently it's ok:\n# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm\nimport os\nimport torch\n\nTORCH_MAJOR = int(torch.__version__.split('.')[0])\nTORCH_MINOR = int(torch.__version__.split('.')[1])\n\n\nif TORCH_MAJOR == 1 and TORCH_MINOR < 8:\n    from torch._six import container_abcs\nelse:\n    import collections.abc as container_abcs\n\n\nclass AmpState(object):\n    def __init__(self):\n        self.hard_override=False\n        self.allow_incoming_model_not_fp32 = False\n        self.verbosity=1\n\n\n# Attribute stash.  Could also just stash things as global module attributes.\n_amp_state = AmpState()\n\n\ndef warn_or_err(msg):\n    if _amp_state.hard_override:\n        print(\"Warning:  \" + msg)\n    else:\n        raise RuntimeError(msg)\n        # I'm not sure if allowing hard_override is a good idea.\n        # + \"  If you're sure you know what you're doing, supply \" +\n        #                    \"hard_override=True to amp.initialize.\")\n\n\ndef maybe_print(msg, rank0=False):\n    distributed = torch.distributed.is_available() and \\\n        torch.distributed.is_initialized() and \\\n        torch.distributed.get_world_size() > 1\n    if _amp_state.verbosity > 0:\n        if rank0:\n            if distributed:\n                if torch.distributed.get_rank() == 0:\n                    print(msg)\n            else:\n                print(msg)\n        else:\n            print(msg)\n\n\n# def iter_params(param_groups):\n#     for group in param_groups:\n#         for p in group['params']:\n#             yield p\n\n\ndef master_params(optimizer):\n    \"\"\"\n    Generator expression that iterates over the params owned by ``optimizer``.\n\n    Args:\n        optimizer: An optimizer previously returned from ``amp.initialize``.\n 
   \"\"\"\n    for group in optimizer.param_groups:\n        for p in group['params']:\n            yield p\n"
  },
  {
    "path": "KoSimCSE/apex/amp/_initialize.py",
    "content": "import torch\nfrom torch._six import string_classes\nimport functools\nimport numpy as np\nimport sys\nfrom types import MethodType\nimport warnings\nfrom ._amp_state import _amp_state, warn_or_err, container_abcs\nfrom .handle import disable_casts\nfrom .scaler import LossScaler\nfrom ._process_optimizer import _process_optimizer\nfrom apex.fp16_utils import convert_network\nfrom ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general\nfrom ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused\n\nif torch.distributed.is_available():\n    from ..parallel import DistributedDataParallel as apex_DDP\n    from ..parallel.LARC import LARC\n\n\ndef to_type(dtype, t):\n    if isinstance(t, torch.Tensor):\n        if not t.is_cuda:\n            # This should not be a hard error, since it may be legitimate.\n            warnings.warn(\"An input tensor was not cuda.\")\n        # GANs require this.\n        # if t.requires_grad:\n        #     warn_or_err(\"input data requires grad.  Since input data is not a model parameter,\\n\"\n        #         \"its gradients will not be properly allreduced by DDP.\")\n        if t.is_floating_point():\n            return t.to(dtype)\n        return t\n    else:\n        # Trust the user's custom batch type, that's all I can do here.\n        return t.to(dtype)\n\n\n# Modified from torch.optim.optimizer.py.  
This is a bit more general than casted_args in utils.py.\ndef applier(value, fn):\n    if isinstance(value, torch.Tensor):\n        return fn(value)\n    elif isinstance(value, string_classes):\n        return value\n    elif isinstance(value, np.ndarray):\n        return value\n    elif hasattr(value, \"to\"): # Allow handling of custom batch classes\n        return fn(value)\n    elif isinstance(value, container_abcs.Mapping):\n        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}\n    elif isinstance(value, container_abcs.Iterable):\n        return type(value)(applier(v, fn) for v in value)\n    else:\n        # Do I want this to fire off even if someone chooses to pass something ordinary like\n        # an int or float?  May be more annoying than it's worth.\n        # print(\"Warning:  unrecognized type in applier.  If your input data is a custom class, \"\n        #     \"provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. \"\n        #     \"Amp will check for your custom to() and invoke it to cast the batch's \"\n        #     \"floating-point Tensors to the appropriate type. \"\n        #     \"Also, if your data is a custom class, it is your responsibility to ensure that \"\n        #     \"any Tensors you want to be cuda are already cuda.\"\n        return value\n\n\ndef check_models(models):\n    for model in models:\n        parallel_type = None\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel):\n            parallel_type = \"torch.nn.parallel.DistributedDataParallel\"\n        if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):\n            parallel_type = \"apex.parallel.DistributedDataParallel\"\n        if isinstance(model, torch.nn.parallel.DataParallel):\n            parallel_type = \"torch.nn.parallel.DataParallel\"\n        if parallel_type is not None:\n            raise RuntimeError(\"Incoming model is an instance of {}. 
\".format(parallel_type) +\n                \"Parallel wrappers should only be applied to the model(s) AFTER \\n\"\n                \"the model(s) have been returned from amp.initialize.\")\n\n\ndef check_params_fp32(models):\n    for model in models:\n        for name, param in model.named_parameters():\n            if param.is_floating_point():\n                if 'Half' in param.type():\n                    warn_or_err(\"Found param {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you do not need to call .half() on your model\\n\"\n                        \"before passing it, no matter what optimization level you choose.\".format(\n                        name, param.type()))\n                elif not param.is_cuda:\n                    warn_or_err(\"Found param {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you need to provide a model with parameters\\n\"\n                        \"located on a CUDA device before passing it no matter what optimization level\\n\"\n                        \"you chose. 
Use model.to('cuda') to use the default device.\".format(\n                        name, param.type()))\n\n        # Backward compatibility for PyTorch 0.4\n        if hasattr(model, 'named_buffers'):\n            buf_iter = model.named_buffers()\n        else:\n            buf_iter = model._buffers\n        for obj in buf_iter:\n            if type(obj)==tuple:\n                name, buf = obj\n            else:\n                name, buf = obj, buf_iter[obj]\n            if buf.is_floating_point():\n                if 'Half' in buf.type():\n                    warn_or_err(\"Found buffer {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you do not need to call .half() on your model\\n\"\n                        \"before passing it, no matter what optimization level you choose.\".format(\n                        name, buf.type()))\n                elif not buf.is_cuda:\n                    warn_or_err(\"Found buffer {} with type {}, expected torch.cuda.FloatTensor.\\n\"\n                        \"When using amp.initialize, you need to provide a model with buffers\\n\"\n                        \"located on a CUDA device before passing it no matter what optimization level\\n\"\n                        \"you chose. Use model.to('cuda') to use the default device.\".format(\n                        name, buf.type()))\n\n\ndef check_optimizers(optimizers):\n    for optim in optimizers:\n        bad_optim_type = None\n        if isinstance(optim, FP16_Optimizer_general):\n            bad_optim_type = \"apex.fp16_utils.FP16_Optimizer\"\n        if isinstance(optim, FP16_Optimizer_for_fused):\n            bad_optim_type = \"apex.optimizers.FP16_Optimizer\"\n        if bad_optim_type is not None:\n            raise RuntimeError(\"An incoming optimizer is an instance of {}. 
\".format(bad_optim_type) +\n                               \"The optimizer(s) passed to amp.initialize() must be bare \\n\"\n                               \"instances of either ordinary Pytorch optimizers, or Apex fused \\n\"\n                               \"optimizers.\\n\")\n\n\nclass O2StateDictHook(object):\n    def __init__(self, fn):\n        self.fn = fn\n\n    def __call__(self, module, state_dict, prefix, local_metadata):\n        for key in state_dict:\n            param = state_dict[key]\n            if 'Half' in param.type():\n                param = param.to(torch.float32)\n                state_dict[key] = param\n\n\ndef _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):\n    from .amp import init as amp_init\n\n    optimizers_was_list = False\n    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):\n        optimizers = [optimizers]\n    elif optimizers is None:\n        optimizers = []\n    elif isinstance(optimizers, list):\n        optimizers_was_list = True\n        check_optimizers(optimizers)\n    else:\n        check_optimizers([optimizers])\n        raise TypeError(\"optimizers must be either a single optimizer or a list of optimizers.\")\n\n    if isinstance(models, torch.nn.Module):\n        models_was_list = False\n        models = [models]\n    elif isinstance(models, list):\n        models_was_list = True\n    else:\n        raise TypeError(\"models must be either a single model or a list of models.\")\n\n    check_models(models)\n\n    if not _amp_state.allow_incoming_model_not_fp32:\n        check_params_fp32(models)\n\n    # In the future, when FP16_Optimizer can be deprecated and master weights can\n    # become an attribute, remember to stash master weights before casting the model.\n\n    if properties.cast_model_type:\n        if properties.keep_batchnorm_fp32:\n            for model in models:\n                convert_network(model, 
properties.cast_model_type)\n        else:\n            for model in models:\n                model.to(properties.cast_model_type)\n\n        input_caster = functools.partial(to_type, properties.cast_model_type)\n        if cast_model_outputs is not None:\n            output_caster = functools.partial(to_type, cast_model_outputs)\n        else:\n            output_caster = functools.partial(to_type, torch.float32)\n\n        for model in models:\n            # Patch the forward method to cast incoming data to the correct type, and\n            # outgoing data to float32, so \"the user never needs to call .half().\"\n            # I like writing things explicitly more than decorators.\n            def patch_forward(old_fwd):\n                def new_fwd(*args, **kwargs):\n                    output = old_fwd(*applier(args, input_caster),\n                                     **applier(kwargs, input_caster))\n                    return applier(output, output_caster)\n                return new_fwd\n\n            model.forward = patch_forward(model.forward)\n\n        # State dict trick to recast any preexisting per-param state tensors\n        for optimizer in optimizers:\n            optimizer.load_state_dict(optimizer.state_dict())\n\n        # patch model.state_dict() to return float32 params\n        for model in models:\n            for module in model.modules():\n                module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))\n\n    elif cast_model_outputs is not None:\n        output_caster = functools.partial(to_type, cast_model_outputs)\n\n        for model in models:\n            def patch_forward(old_fwd):\n                def new_fwd(*args, **kwargs):\n                    output = old_fwd(*args, **kwargs)\n                    return applier(output, output_caster)\n                return new_fwd\n\n            model.forward = patch_forward(model.forward)\n\n    for i, optimizer in enumerate(optimizers):\n        
optimizers[i] = _process_optimizer(optimizer, properties)\n\n    _amp_state.loss_scalers = []\n    for _ in range(num_losses):\n        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,\n                                                  min_loss_scale=_amp_state.min_loss_scale,\n                                                  max_loss_scale=_amp_state.max_loss_scale))\n\n    if properties.patch_torch_functions:\n        # handle is unused here. It's accessible later through a global value anyway.\n        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))\n        for optimizer in optimizers:\n            # Disable Amp casting for the optimizer step, because it should only be\n            # applied to FP32 master params anyway.\n            def patch_step(old_step):\n                def new_step(self, *args, **kwargs):\n                    with disable_casts():\n                        output = old_step(*args, **kwargs)\n                    return output\n                return new_step\n\n            optimizer.step = MethodType(patch_step(optimizer.step), optimizer)\n\n    if optimizers_was_list:\n        if models_was_list:\n            return models, optimizers\n        else:\n            return models[0], optimizers\n    else:\n        if models_was_list:\n            if len(optimizers) == 0:\n                return models\n            else:\n                return models, optimizers[0]\n        else:\n            if len(optimizers) == 0:\n                return models[0]\n            else:\n                return models[0], optimizers[0]\n"
  },
  {
    "path": "KoSimCSE/apex/amp/_process_optimizer.py",
    "content": "import types\nfrom ..fp16_utils import master_params_to_model_params\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom ._amp_state import maybe_print\nimport torch\nfrom ..optimizers import FusedSGD\n\n\nclass AmpOptimizerState(object):\n    def __init__(self):\n        pass\n\n\ndef _master_params_to_model_params(self):\n    stash = self._amp_stash\n    if multi_tensor_applier.available:\n        if len(stash.all_fp16_params) > 0:\n            multi_tensor_applier(\n                stash.multi_tensor_scale,\n                stash.dummy_overflow_buf,\n                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],\n                1.0)\n    else:\n        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):\n            master_params_to_model_params(fp16_group, fp32_from_fp16_group)\n\n\ndef lazy_init_with_master_weights(self):\n        stash = self._amp_stash\n        stash.fp16_groups = []\n        stash.fp32_from_fp16_groups = []\n        stash.fp32_from_fp32_groups = []\n        for i, param_group in enumerate(self.param_groups):\n            # maybe_print(\"FP16_Optimizer processing param group {}:\".format(i))\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(param_group['params']):\n                if param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        # maybe_print(\"FP16_Optimizer received torch.cuda.HalfTensor with {}\"\n                        #             .format(param.size()))\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        param_group['params'][i] = master_param\n                        
fp32_from_fp16_params_this_group.append(master_param)\n                        # Reset existing state dict key to the new master param.\n                        # We still need to recast per-param state tensors, if any, to FP32.\n                        if param in self.state:\n                           self.state[master_param] = self.state.pop(param)\n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        # maybe_print(\"FP16_Optimizer received torch.cuda.FloatTensor with {}\"\n                        #             .format(param.size()))\n                        fp32_params_this_group.append(param)\n                        param_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Optimizer's parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. \"\n                                        \"Received {}\".format(param.type()))\n\n            stash.fp16_groups.append(fp16_params_this_group)\n            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            stash.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n        stash.all_fp16_params = []\n        for group in stash.fp16_groups:\n            stash.all_fp16_params += group\n\n        stash.all_fp32_from_fp16_params = []\n        for group in stash.fp32_from_fp16_groups:\n            stash.all_fp32_from_fp16_params += group\n\n        stash.all_fp32_from_fp32_params = []\n        for group in stash.fp32_from_fp32_groups:\n            stash.all_fp32_from_fp32_params += group\n\n        # all_fp16_grad_stash is only needed for fused optimizers.\n        stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]\n        # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]\n        stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]\n\n        for param in 
stash.all_fp32_from_fp16_params:\n            param.grad = None\n\n        for param in stash.all_fp32_from_fp32_params:\n            param.grad = None\n\n        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors\n        self.load_state_dict(self.state_dict())\n\n\ndef post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):\n        grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0\n\n        # not much to do if scale == 1.0 and static scaling\n        if scaler.loss_scale() == 1.0 and not scaler.dynamic:\n            # Clear the stash.\n            for i in range(len(stashed_grads)):\n                stashed_grads[i] = None\n            return\n        \n        if scale_override is not None:\n            grads_have_scale, stashed_have_scale, out_scale = scale_override\n\n        # This is a lot of python overhead...\n        grads_needing_unscale = []\n        grads_needing_unscale_with_stash = []\n        stashed = []\n        for param, stashed_grad in zip(params, stashed_grads):\n            if param.grad is None and stashed_grad is not None:\n                param.grad = stashed_grad\n            elif param.grad is not None and stashed_grad is None:\n                grads_needing_unscale.append(param.grad)\n            elif param.grad is not None and stashed_grad is not None:\n                grads_needing_unscale_with_stash.append(param.grad)\n                stashed.append(stashed_grad)\n            else: # param.grad is None and stashed_grad is None\n                continue\n\n        # unscale() implements grads*(1/scale), so \"scale\" should be grads_have_scale/out_scale.\n        if len(grads_needing_unscale) > 0:\n            scaler.unscale(\n                grads_needing_unscale,\n                grads_needing_unscale,\n                None, # unused_scale, currently present to avoid API breakage elsewhere\n                models_are_masters=True,\n 
               scale_override=grads_have_scale/out_scale)\n\n        if len(grads_needing_unscale_with_stash) > 0:\n            scaler.unscale_with_stashed(\n                grads_needing_unscale_with_stash,\n                stashed,\n                grads_needing_unscale_with_stash,\n                scale_override=(grads_have_scale, stashed_have_scale, out_scale))\n\n        # Clear the stash.\n        for i in range(len(stashed_grads)):\n            stashed_grads[i] = None\n\n\ndef prepare_backward_with_master_weights(self):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    for i, param in enumerate(stash.all_fp16_params):\n        # Set up to leverage grad copy elision.\n        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.\n        param.grad = None\n\n    # for i, param in enumerate(stash.all_fp32_from_fp16_params):\n    #     stash.all_fp32_from_fp16_grad_stash[i] = param.grad\n\n    for i, param in enumerate(stash.all_fp32_from_fp32_params):\n        stash.all_fp32_from_fp32_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n\ndef post_backward_with_master_weights(self, scaler):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    # This is a lot of python overhead...\n    fp16_grads_needing_unscale = []\n    new_fp32_grads = []\n    fp16_grads_needing_unscale_with_stash = []\n    preexisting_fp32_grads = []\n    for fp16_param, fp32_param in zip(stash.all_fp16_params,\n                                      stash.all_fp32_from_fp16_params):\n        if fp16_param.grad is None and fp32_param.grad is not None:\n            continue\n        elif fp16_param.grad is not None and fp32_param.grad is None:\n            fp32_param.grad = torch.empty_like(fp32_param)\n            fp16_grads_needing_unscale.append(fp16_param.grad)\n            new_fp32_grads.append(fp32_param.grad)\n        elif fp16_param.grad is not None and 
fp32_param.grad is not None:\n            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)\n            preexisting_fp32_grads.append(fp32_param.grad)\n        else: # fp16_param.grad is None and fp32_param.grad is None:\n            continue\n\n    if len(fp16_grads_needing_unscale) > 0:\n        scaler.unscale(\n            fp16_grads_needing_unscale,\n            new_fp32_grads,\n            scaler.loss_scale(),\n            models_are_masters=False)\n\n    if len(fp16_grads_needing_unscale_with_stash) > 0:\n        scaler.unscale_with_stashed(\n            fp16_grads_needing_unscale_with_stash,\n            preexisting_fp32_grads,\n            preexisting_fp32_grads)\n\n    # fp32 params can be treated as they would be in the \"no_master_weights\" case.\n    post_backward_models_are_masters(\n        scaler,\n        stash.all_fp32_from_fp32_params,\n        stash.all_fp32_from_fp32_grad_stash)\n\n\ndef lazy_init_no_master_weights(self):\n    stash = self._amp_stash\n    stash.all_fp16_params = []\n    stash.all_fp32_params = []\n    for i, param_group in enumerate(self.param_groups):\n        for i, param in enumerate(param_group['params']):\n            if param.type() == 'torch.cuda.HalfTensor':\n                stash.all_fp16_params.append(param)\n            elif param.type() == 'torch.cuda.FloatTensor':\n                stash.all_fp32_params.append(param)\n            else:\n                raise TypeError(\"Optimizer's parameters must be either \"\n                                \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"\n                                \"Received {}\".format(param.type()))\n\n    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]\n    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]\n\n\ndef prepare_backward_no_master_weights(self):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    for i, param in enumerate(stash.all_fp16_params):\n        stash.all_fp16_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n    for i, param in enumerate(stash.all_fp32_params):\n        stash.all_fp32_grad_stash[i] = param.grad\n        # Set up to leverage grad copy elision:\n        param.grad = None\n\n\ndef post_backward_no_master_weights(self, scaler):\n    stash = self._amp_stash\n\n    self._amp_lazy_init()\n\n    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),\n             (stash.all_fp32_params, stash.all_fp32_grad_stash))\n\n    for params, stashed_grads in split_types:\n        post_backward_models_are_masters(scaler, params, stashed_grads)\n\n\n#####################################################################################\n# FusedSGD versions\n#####################################################################################\n\n# FusedSGD never explicitly materializes the fp32 gradients for \"fp32 from fp16\" master params\n# outside the kernel, so we must accumulate directly into the model grads.\ndef prepare_backward_with_master_weights_FusedSGD(self):\n    if self.materialize_master_grads:\n        prepare_backward_with_master_weights(self)\n    else:\n        stash = self._amp_stash\n\n        self._amp_lazy_init()\n\n        for i, param in enumerate(stash.all_fp16_params):\n            stash.all_fp16_grad_stash[i] = param.grad\n            # Set up to leverage grad copy elision:\n            param.grad = None\n\n        for i, param in enumerate(stash.all_fp32_from_fp32_params):\n            stash.all_fp32_from_fp32_grad_stash[i] = 
param.grad\n            # Set up to leverage grad copy elision:\n            param.grad = None\n\n\ndef post_backward_with_master_weights_FusedSGD(self, scaler):\n    if self.materialize_master_grads:\n        post_backward_with_master_weights(self, scaler)\n    else:\n        stash = self._amp_stash\n\n        self._amp_lazy_init()\n\n        grads_have_scale = scaler.loss_scale()\n        stashed_have_scale = self.most_recent_scale\n        out_scale = grads_have_scale\n        if self.scale_set_by_backward:\n            out_scale = min(grads_have_scale, self.most_recent_scale)\n\n        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),\n                 (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))\n\n\n        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.\n        # stashed_grads are scaled by self.most_recent_scale.\n        for params, stashed_grads in split_types:\n            post_backward_models_are_masters(scaler, params, stashed_grads,\n                                             (grads_have_scale, stashed_have_scale, out_scale))\n\n        self.most_recent_scale = out_scale\n        self.scale_set_by_backward = True\n\n\ndef prepare_backward_no_master_weights_FusedSGD(self):\n    prepare_backward_no_master_weights(self)\n\n\ndef post_backward_no_master_weights_FusedSGD(self, scaler):\n    post_backward_no_master_weights(self, scaler)\n\n\ndef _amp_lazy_init(self):\n    stash = self._amp_stash\n\n    if not stash.lazy_init_called:\n        self._lazy_init_maybe_master_weights()\n        stash.lazy_init_called = True\n\n\ndef _process_optimizer(optimizer, properties):\n    if hasattr(optimizer, \"_amp_stash\"):\n        raise RuntimeError(\"A given optimizer should only be passed through amp.initialize once.\")\n    else:\n        optimizer._amp_stash = AmpOptimizerState()\n\n    optimizer._amp_stash.lazy_init_called = False\n    optimizer._amp_stash.already_patched = False\n    
optimizer._amp_stash.params_have_scaled_gradients = False\n\n    for name in (\"_lazy_init_maybe_master_weights\",\n                 \"_master_params_to_model_params\",\n                 \"_prepare_amp_backward\",\n                 \"_post_amp_backward\",\n                 \"_amp_lazy_init\"):\n        if hasattr(optimizer, name):\n            raise RuntimeError(\"Incoming optimizer already has {} defined.\".format(name))\n\n    # TODO:  Centralize exposure and import error checking for the C backend.\n    if multi_tensor_applier.available:\n        import amp_C\n        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale\n        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);\n\n    if properties.master_weights:\n        optimizer._lazy_init_maybe_master_weights = types.MethodType(\n            lazy_init_with_master_weights, optimizer)\n\n        optimizer._master_params_to_model_params = types.MethodType(\n            _master_params_to_model_params, optimizer)\n\n        old_step = optimizer.step\n        def new_step(self, closure=None):\n            if closure is not None:\n                raise RuntimeError(\"Currently, Amp does not support closure use with optimizers.\")\n            retval = old_step()\n            if not isinstance(self, FusedSGD):\n                self._master_params_to_model_params()\n            # Clear the master grads that wouldn't be zeroed by model.zero_grad()\n            for param in self._amp_stash.all_fp32_from_fp16_params:\n                param.grad = None\n            return retval\n        optimizer.step = types.MethodType(new_step, optimizer)\n\n        old_zero_grad = optimizer.zero_grad\n        def new_zero_grad(self):\n            stash = self._amp_stash\n            self._amp_lazy_init()\n            # Zero the model grads.\n            for param in stash.all_fp16_params:\n                if param.grad is 
not None:\n                    param.grad.detach_()\n                    param.grad.zero_()\n            for param in stash.all_fp32_from_fp32_params:\n                if param.grad is not None:\n                    param.grad.detach_()\n                    param.grad.zero_()\n            # Clear the master grads that are independent of model grads\n            for param in self._amp_stash.all_fp32_from_fp16_params:\n                param.grad = None\n        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)\n\n        if isinstance(optimizer, FusedSGD):\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_with_master_weights_FusedSGD, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_with_master_weights_FusedSGD, optimizer)\n        else:\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_with_master_weights, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_with_master_weights, optimizer)\n    else:\n        optimizer._lazy_init_maybe_master_weights = types.MethodType(\n            lazy_init_no_master_weights, optimizer)\n\n        if isinstance(optimizer, FusedSGD):\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_no_master_weights_FusedSGD, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_no_master_weights_FusedSGD, optimizer)\n        else:\n            optimizer._prepare_amp_backward = types.MethodType(\n                prepare_backward_no_master_weights, optimizer)\n            optimizer._post_amp_backward = types.MethodType(\n                post_backward_no_master_weights, optimizer)\n\n    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)\n\n    old_add_param_group = optimizer.add_param_group\n\n    def 
new_add_param_group(self, new_group):\n        stash = self._amp_stash\n\n        if not stash.lazy_init_called:\n            self._lazy_init_maybe_master_weights()\n            stash.lazy_init_called = True\n\n        assert isinstance(new_group, dict), \"param group must be a dict\"\n\n        new_params = new_group['params']\n        if isinstance(new_params, torch.Tensor):\n            new_group['params'] = [new_params]\n        elif isinstance(new_params, set):\n            raise TypeError('optimizer parameters need to be organized in ordered collections, but '\n                            'the ordering of tensors in sets will change between runs. Please use a list instead.')\n        else:\n            new_group['params'] = list(new_params)\n\n        if properties.master_weights:\n            # Mutate new_group in-place to use FP32 master params\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(new_group['params']):\n                if param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        new_group['params'][i] = master_param\n                        fp32_from_fp16_params_this_group.append(master_param)\n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        fp32_params_this_group.append(param)\n                        new_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Optimizer's parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"\n                                        \"Received {}\".format(param.type()))\n\n            stash.fp16_groups.append(fp16_params_this_group)\n            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            stash.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n            stash.all_fp16_params += fp16_params_this_group\n            stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group\n            stash.all_fp32_from_fp32_params += fp32_params_this_group\n\n            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]\n            stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]\n\n            # It should be ok to let params be added with existing .grad attributes.\n            # for param in fp16_params_this_group:\n            #     param.grad = None\n\n            # for param in fp32_from_fp16_params_this_group:\n            #     param.grad = None\n\n            # for param in stash.fp32_params_this_group:\n            #     param.grad = None\n        else:\n            for param in new_group['params']:\n                if param.type() == 'torch.cuda.HalfTensor':\n                    stash.all_fp16_params.append(param)\n                    stash.all_fp16_grad_stash.append(None)\n                elif param.type() == 'torch.cuda.FloatTensor':\n                    stash.all_fp32_params.append(param)\n                    stash.all_fp32_grad_stash.append(None)\n                else:\n                    raise TypeError(\"Optimizer's parameters must be either \"\n                                    \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. \"\n                                    \"Received {}\".format(param.type()))\n\n        old_add_param_group(new_group)\n\n    optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)\n\n    return optimizer\n"
  },
  {
    "path": "KoSimCSE/apex/amp/amp.py",
    "content": "from . import compat, rnn_compat, utils, wrap\nfrom .handle import AmpHandle, NoOpHandle\nfrom .lists import functional_overrides, torch_overrides, tensor_overrides\nfrom ._amp_state import _amp_state\nfrom .frontend import *\n\nimport functools\nimport itertools\n\nimport torch\n\n\n_DECORATOR_HANDLE = None\n_USER_CAST_REGISTRY = set()\n_USER_PROMOTE_REGISTRY = set()\n\n\ndef _decorator_helper(orig_fn, cast_fn, wrap_fn):\n    def wrapper(*args, **kwargs):\n        handle = _DECORATOR_HANDLE\n        if handle is None or not handle.is_active():\n            return orig_fn(*args, **kwargs)\n        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,\n                                  handle.verbose)\n        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)\n    return wrapper\n\n\n# Decorator form\ndef half_function(fn):\n    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)\n    return _decorator_helper(fn, utils.maybe_half, wrap_fn)\n\n\ndef float_function(fn):\n    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)\n    return _decorator_helper(fn, utils.maybe_float, wrap_fn)\n\n\ndef promote_function(fn):\n    wrap_fn = functools.partial(wrap.make_promote_wrapper)\n    return _decorator_helper(fn, utils.maybe_float, wrap_fn)\n\n\n# Registry form\ndef register_half_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n            name, module))\n    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))\n\n\ndef register_float_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n            name, module))\n    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))\n\n\ndef register_promote_function(module, name):\n    if not hasattr(module, name):\n        raise ValueError('No function named {} in module {}.'.format(\n          
  name, module))\n    _USER_PROMOTE_REGISTRY.add((module, name))\n\n\n# Top-level function to insert _all_ the hooks.\ndef init(enabled=True, loss_scale=\"dynamic\", enable_caching=True, verbose=False, allow_banned=False):\n    global _DECORATOR_HANDLE\n\n    if not enabled:\n        handle = NoOpHandle()\n        _DECORATOR_HANDLE = handle\n        return handle\n\n    handle = AmpHandle(loss_scale, enable_caching, verbose)\n\n    # 0) Force-{fp16, fp32} for user-annotated functions\n    for mod, fn, cast_fn in _USER_CAST_REGISTRY:\n        try_caching = (cast_fn == utils.maybe_half)\n        wrap.cached_cast(mod, fn, cast_fn, handle,\n                         try_caching, verbose)\n    _USER_CAST_REGISTRY.clear()\n\n    # 0.5) Force-promote for user-annotated functions\n    for mod, fn in _USER_PROMOTE_REGISTRY:\n        wrap.promote(mod, fn, handle, verbose)\n    _USER_PROMOTE_REGISTRY.clear()\n\n    # 1) Force-{fp16, fp32} on white- / black-list functions\n    override_modules = [functional_overrides,\n                        torch_overrides,\n                        tensor_overrides]\n    cast_table = [('FP16_FUNCS', utils.maybe_half),\n                  ('FP32_FUNCS', utils.maybe_float)]\n    for module, (list_name, cast_fn) in itertools.product(override_modules,\n                                                          cast_table):\n        for fn in getattr(module, list_name):\n            try_caching = (cast_fn == utils.maybe_half)\n            wrap.cached_cast(module.MODULE, fn, cast_fn, handle,\n                             try_caching, verbose)\n\n    # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist\n    #      methods on FloatTensor, since they're distinct types.\n    if compat.tensor_is_float_tensor():\n        for fn in tensor_overrides.FP16_FUNCS:\n            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,\n                             handle, try_caching=True, verbose=verbose)\n        for fn in 
tensor_overrides.FP32_FUNCS:\n            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,\n                             handle, try_caching=False, verbose=verbose)\n\n    # 2) Enable type-promotion on multi-arg functions and methods.\n    #    NB: special handling for sequence fns (e.g. `torch.cat`).\n    promote_modules = [torch_overrides, tensor_overrides]\n    promote_table = [('CASTS', wrap.promote),\n                     ('SEQUENCE_CASTS', wrap.sequence_promote)]\n    for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,\n                                                                  promote_table):\n        for fn in getattr(promote_mod, list_name):\n            promote_fn(promote_mod.MODULE, fn, handle, verbose)\n\n    # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types\n    if compat.tensor_is_float_tensor():\n        for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,\n                                                               torch.cuda.HalfTensor],\n                                                              promote_table):\n            for fn in getattr(tensor_overrides, list_name):\n                promote_fn(cls, fn, handle, verbose)\n\n    # 3) For any in-place version of a blacklist function, error if any input is fp16.\n    #    NB: this is overly conservative.\n    for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):\n        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)\n\n    # 3.5) For any in-place blacklist method, error if called on fp16 tensor\n    for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):\n        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)\n        if compat.tensor_is_float_tensor():\n            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)\n\n    # 4) For other in-place methods, match the type of self tensor\n    for fn in utils.as_inplace(itertools.chain(\n    
        tensor_overrides.FP16_FUNCS,\n            tensor_overrides.CASTS)):\n        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)\n        if compat.tensor_is_float_tensor():\n            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)\n            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)\n\n    # 5) RNNs + RNN cells are whitelisted specially\n    if rnn_compat.has_old_rnns():\n        wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)\n    if not rnn_compat.has_old_rnns():\n        # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.\n        torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()\n        # Wrap all the rnns\n        for x in rnn_compat.RNN_NAMES:\n            wrap.new_rnn_cast(x.upper(), handle, verbose)\n\n    # Wrap all the RNN cells\n    rnn_compat.whitelist_rnn_cells(handle, verbose)\n\n    # 6) Place error+print message on banned functions.\n    #    Or, if allow_banned, then cast to FP32.\n    for fn, err_msg in functional_overrides.BANNED_FUNCS:\n        if allow_banned:\n            wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,\n                             handle, try_caching=True, verbose=verbose)\n        else:\n            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)\n\n    _DECORATOR_HANDLE = handle\n\n    _amp_state.handle = handle\n\n    return handle\n"
  },
  {
    "path": "KoSimCSE/apex/amp/compat.py",
    "content": "import torch\n\n# True for post-0.4, when Variables/Tensors merged.\ndef variable_is_tensor():\n    v = torch.autograd.Variable()\n    return isinstance(v, torch.Tensor)\n\ndef tensor_is_variable():\n    x = torch.Tensor()\n    return type(x) == torch.autograd.Variable\n\n# False for post-0.4\ndef tensor_is_float_tensor():\n    x = torch.Tensor()\n    return type(x) == torch.FloatTensor\n\n# Akin to `torch.is_tensor`, but returns True for Variable\n# objects in pre-0.4.\ndef is_tensor_like(x):\n    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)\n\n# Wraps `torch.is_floating_point` if present, otherwise checks\n# the suffix of `x.type()`.\ndef is_floating_point(x):\n    if hasattr(torch, 'is_floating_point'):\n        return torch.is_floating_point(x)\n    try:\n        torch_type = x.type()\n        return torch_type.endswith('FloatTensor') or \\\n            torch_type.endswith('HalfTensor') or \\\n            torch_type.endswith('DoubleTensor')\n    except AttributeError:\n        return False\n\ndef scalar_python_val(x):\n    if hasattr(x, 'item'):\n        return x.item()\n    else:\n        if isinstance(x, torch.autograd.Variable):\n            return x.data[0]\n        else:\n            return x[0]\n\n# Accounts for the possibility that some ops may be removed from a namespace.\ndef filter_attrs(module, attrs):\n    return list(attrname for attrname in attrs if hasattr(module, attrname))\n"
  },
  {
    "path": "KoSimCSE/apex/amp/frontend.py",
    "content": "import torch\nfrom ._initialize import _initialize\nfrom ._amp_state import _amp_state, warn_or_err, maybe_print\nfrom collections import OrderedDict\n\n\nclass Properties(object):\n    \"\"\"\n    This class has two purposes: to establish a set of default properties,\n    and to route setting of these attributes through __setattr__ so that (in theory)\n    they can be checked for consistency with other existing args.\n    \"\"\"\n    def __init__(self):\n        self.options = {\n            \"enabled\" : False,\n            \"opt_level\" : None,\n            \"cast_model_type\" : None,\n            \"patch_torch_functions\" : False,\n            \"keep_batchnorm_fp32\" : None,\n            \"master_weights\" : None,\n            \"loss_scale\" : 1.0,\n            # Reserved for future functionality\n            # \"fused_optimizer\" : False,\n            # \"enable_ddp_interop\" : False,\n            }\n\n    \"\"\"\n    This function allows updating several options at a time without routing through\n    __setattr__ checks, to avoid \"you can't get there from here\" scenarios.\n    Currently not intended to be exposed; users are expected to select an opt_level\n    and apply consistent modifications.\n    \"\"\"\n    def _update_options_dict(self, new_options):\n        for k, v in new_options.items():\n            if k in self.options:\n                self.options[k] = v\n            else:\n                raise ValueError(\"Tried to set unexpected option {}\".format(k))\n    \"\"\"\n    The members of \"options\" are not direct attributes of self, so access attempts\n    will roll down to __getattr__.  
This borrows from the logic in torch.nn.Module.\n    \"\"\"\n    def __getattr__(self, name):\n        if \"options\" in self.__dict__:\n            options =  self.__dict__[\"options\"]\n            if name in options:\n                return options[name]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, name))\n\n    def __setattr__(self, name, value):\n        if \"options\" in self.__dict__:\n            if name in self.options:\n                # print(\"setting {} {}\".format(name, value))\n                if name == \"cast_model_type\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        if value is not False:\n                            if value is not torch.float32:\n                                warn_or_err(\"O1 inserts casts around Torch functions rather than \"\n                                            \"model weights, so with O1, the model weights themselves \"\n                                            \"should remain FP32. If you wish to cast the model to a \"\n                                            \"different type, use opt_level='O2' or 'O3'. 
\" +\n                                            \"cast_model_type was {}\".format(value))\n                    self.options[name] = value\n                elif name == \"patch_torch_functions\":\n                    if self.opt_level != \"O1\" and value:\n                        warn_or_err(\"Currently, patch_torch_functions=True should only be set by \"\n                                    \"selecting opt_level='O1'.\")\n                    self.options[name] = value\n                elif name == \"keep_batchnorm_fp32\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        warn_or_err(\"With opt_level O1, batchnorm functions are automatically patched \"\n                                    \"to run in FP32, so keep_batchnorm_fp32 should be None.\" +\n                                    \" keep_batchnorm_fp32 was {}\".format(value))\n                    if value == \"False\":\n                        self.options[name] = False\n                    elif value == \"True\":\n                        self.options[name] = True\n                    else:\n                        assert (value is True or value is False or value is None),\\\n                            \"keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', \"\\\n                            \"or None, found keep_batchnorm_fp32={}\".format(value)\n                        self.options[name] = value\n                elif name == \"master_weights\":\n                    if self.opt_level == \"O1\" and value is not None:\n                        warn_or_err(\"It doesn't make sense to use master_weights with O1. 
\"\n                                    \"With O1, your model weights themselves should be FP32.\")\n                    self.options[name] = value\n                elif name == \"loss_scale\":\n                    if value == \"dynamic\":\n                        self.options[name] = value\n                    else:\n                        self.options[name] = float(value)\n                else:\n                    self.options[name] = value\n        else:\n            super(Properties, self).__setattr__(name, value)\n\n\n\"\"\" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. \"\"\"\n\nclass O3:\n    brief = \"O3:  Pure FP16 training.\"\n    more = \"Calls .half() on your model, converting the entire model to FP16.\\n\"\\\n        \"A casting operation is also inserted to cast incoming Tensors to FP16,\\n\"\\\n        \"so you don't need to change your data pipeline.\\n\"\\\n        \"This mode is useful for establishing a performance ceiling.\\n\"\\\n        \"It's also possible training may 'just work' in this mode.\\n\"\\\n        \"If not, try other optimization levels.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O3\"\n        properties.cast_model_type = torch.float16\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = False\n        properties.master_weights = False\n        properties.loss_scale = 1.0\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O2:\n    brief = \"O2:  FP16 training with FP32 batchnorm and FP32 master weights.\\n\"\n    more = \"Calls .half() on your model, converting the entire model (except for batchnorms)\\n\"\\\n        \"to FP16.  
Batchnorms are retained in FP32 for additional stability.\\n\"\\\n        \"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\\n\"\\\n        \"your data pipeline.\\n\"\\\n        \"O2 creates FP32 master weights outside the model and patches any optimizers to update\\n\"\\\n        \"these master weights, then copy the master weights into the FP16 model weights.\\n\"\\\n        \"Master weights can also improve convergence and stability.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O2\"\n        properties.cast_model_type = torch.float16\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = True\n        properties.master_weights = True\n        properties.loss_scale = \"dynamic\"\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O1:\n    brief = \"O1:  Insert automatic casts around Pytorch functions and Tensor methods.\\n\"\n    more = \"The type of your model's weights is not altered.  
However, internally,\\n\"\\\n        \"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\\n\"\\\n        \"while operations that might benefit from the additional stability of FP32 are patched\\n\"\\\n        \"to cast their inputs to fp32.\\n\"\\\n        \"O1 is the safest way to try mixed precision training, and is recommended when\\n\"\\\n        \"trying mixed precision training for the first time.\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O1\"\n        properties.cast_model_type = None\n        properties.patch_torch_functions = True\n        properties.keep_batchnorm_fp32 = None\n        properties.master_weights = None\n        properties.loss_scale = \"dynamic\"\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nclass O0:\n    brief = \"O0:  Pure FP32 training.\\n\"\n    more = \"Your models are checked to make sure parameters are FP32, but otherwise the\\n\"\\\n        \"types of weights and internal Pytorch operations are not altered.  
This mode disables any\\n\"\\\n        \"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\\n\"\n\n    def __call__(self, properties):\n        properties.enabled = True\n        properties.opt_level = \"O0\"\n        properties.cast_model_type = torch.float32\n        properties.patch_torch_functions = False\n        properties.keep_batchnorm_fp32 = None\n        properties.master_weights = False\n        properties.loss_scale = 1.0\n        # properties.fused_optimizer = False\n        # properties.enable_ddp_interop = False\n        return properties # modified in place so this isn't really necessary\n\n\nopt_levels = {\"O3\": O3(),\n              \"O2\": O2(),\n              \"O1\": O1(),\n              \"O0\": O0()}\n\n\n# allow user to directly pass Properties struct as well?\ndef initialize(\n    models,\n    optimizers=None,\n    enabled=True,\n    opt_level=\"O1\",\n    cast_model_type=None,\n    patch_torch_functions=None,\n    keep_batchnorm_fp32=None,\n    master_weights=None,\n    loss_scale=None,\n    cast_model_outputs=None,\n    num_losses=1,\n    verbosity=1,\n    min_loss_scale=None,\n    max_loss_scale=2.**24\n    ):\n    \"\"\"\n    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the\n    chosen ``opt_level`` and overridden properties, if any.\n\n    ``amp.initialize`` should be called **after** you have finished\n    constructing your model(s) and\n    optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.\n    See `Distributed training`_ in the Imagenet example.\n\n    Currently, ``amp.initialize`` should only be called **once**,\n    although it can process an arbitrary number of\n    models and optimizers (see the corresponding `Advanced Amp Usage topic`_).\n    If you think your use case requires ``amp.initialize`` to be called more than once,\n    `let us know`_.\n\n    Any property keyword argument that is not 
``None`` will be interpreted as a manual override.\n\n    To prevent having to rewrite anything else in your script, name the returned models/optimizers\n    to replace the passed models/optimizers, as in the code sample below.\n\n    Args:\n        models (torch.nn.Module or list of torch.nn.Modules):  Models to modify/cast.\n        optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers):  Optimizers to modify/cast.\n            REQUIRED for training, optional for inference.\n        enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script\n            should run as if Amp were not present.\n        opt_level (str, optional, default=\"O1\"):  Pure or mixed precision optimization level.  Accepted values are\n            \"O0\", \"O1\", \"O2\", and \"O3\", explained in detail above.\n        cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see\n            above.\n        patch_torch_functions (bool, optional, default=None):  Optional property override.\n        keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If\n            passed as a string, must be the string \"True\" or \"False\".\n        master_weights (bool, optional, default=None):  Optional property override.\n        loss_scale (float or str, optional, default=None):  Optional property override.  If passed as a string,\n            must be a string representing a number, e.g., \"128.0\", or the string \"dynamic\".\n        cast_model_outputs (torch.dtype, optional, default=None):  Option to ensure that the outputs\n            of your model(s) are always cast to a particular type regardless of ``opt_level``.\n        num_losses (int, optional, default=1):  Option to tell Amp in advance how many losses/backward\n            passes you plan to use.  
When used in conjunction with the ``loss_id`` argument to\n            ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,\n            which can improve stability.  See \"Multiple models/optimizers/losses\"\n            under `Advanced Amp Usage`_ for examples.  If ``num_losses`` is left to 1, Amp will still\n            support multiple losses/backward passes, but use a single global loss scale\n            for all of them.\n        verbosity (int, default=1):  Set to 0 to suppress Amp-related output.\n        min_loss_scale (float, default=None):  Sets a floor for the loss scale values that can be chosen by dynamic\n            loss scaling.  The default value of None means that no floor is imposed.\n            If dynamic loss scaling is not used, `min_loss_scale` is ignored.\n        max_loss_scale (float, default=2.**24):  Sets a ceiling for the loss scale values that can be chosen by\n            dynamic loss scaling.  If dynamic loss scaling is not used, `max_loss_scale` is ignored.\n\n    Returns:\n        Model(s) and optimizer(s) modified according to the ``opt_level``.\n        If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will\n        also be a list.\n\n    Permissible invocations::\n\n        model, optim = amp.initialize(model, optim,...)\n        model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)\n        [model1, model2], optim = amp.initialize([model1, model2], optim,...)\n        [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)\n\n        # This is not an exhaustive list of the cross product of options that are possible,\n        # just a set of examples.\n        model, optim = amp.initialize(model, optim, opt_level=\"O0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O0\", loss_scale=\"dynamic\"|128.0|\"128.0\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O1\") # 
uses loss_scale=\"dynamic\" default\n        model, optim = amp.initialize(model, optim, opt_level=\"O1\", loss_scale=128.0|\"128.0\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\") # uses loss_scale=\"dynamic\" default\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\", loss_scale=128.0|\"128.0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O2\", keep_batchnorm_fp32=True|False|\"True\"|\"False\")\n\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\") # uses loss_scale=1.0 default\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\", loss_scale=\"dynamic\"|128.0|\"128.0\")\n        model, optim = amp.initialize(model, optim, opt_level=\"O3\", keep_batchnorm_fp32=True|False|\"True\"|\"False\")\n\n    The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.\n\n    .. _`Distributed training`:\n        https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training\n\n    .. _`Imagenet example`:\n        https://github.com/NVIDIA/apex/tree/master/examples/imagenet\n\n    .. _`Advanced Amp Usage`:\n        https://nvidia.github.io/apex/advanced.html\n\n    .. _`Advanced Amp Usage topic`:\n        https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses\n\n    .. _`let us know`:\n        https://github.com/NVIDIA/apex/issues\n    \"\"\"\n    _amp_state.opt_properties = Properties()\n    _amp_state.verbosity = verbosity\n\n    if not enabled:\n        if optimizers is None:\n            return models\n        else:\n            return models, optimizers\n\n    if not torch.backends.cudnn.enabled:\n        raise RuntimeError(\n            \"Amp requires torch.backends.cudnn.enabled = True\")\n\n    if opt_level not in opt_levels:\n        raise RuntimeError(\n            \"Unexpected optimization level {}. \".format(opt_level) +\n            \"Options are 'O0', 'O1', 'O2', 'O3'.  
Note that in `O0`, `O1`, etc., the prefix O is the letter O, \" +\n            \"not the number zero.\")\n    else:\n        _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)\n        maybe_print(\"Selected optimization level {}\".format(opt_levels[opt_level].brief), True)\n        maybe_print(\"Defaults for this optimization level are:\", True)\n        for k, v in _amp_state.opt_properties.options.items():\n            maybe_print(\"{:22} : {}\".format(k, v), True)\n\n    _amp_state.min_loss_scale = min_loss_scale\n    _amp_state.max_loss_scale = max_loss_scale\n\n    maybe_print(\"Processing user overrides (additional kwargs that are not None)...\", True)\n    # I chose to have the keyword arguments listed directly in the argument list,\n    # instead of **kwargs, so I can't use kwargs.items() here.\n    if enabled is not None:\n        _amp_state.opt_properties.enabled = enabled\n    if opt_level is not None:\n        _amp_state.opt_properties.opt_level = opt_level\n    if cast_model_type is not None:\n        _amp_state.opt_properties.cast_model_type = cast_model_type\n    if patch_torch_functions is not None:\n        _amp_state.opt_properties.patch_torch_functions = patch_torch_functions\n    if keep_batchnorm_fp32 is not None:\n        _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32\n    if master_weights is not None:\n        _amp_state.opt_properties.master_weights = master_weights\n    if loss_scale is not None:\n        _amp_state.opt_properties.loss_scale = loss_scale\n\n    maybe_print(\"After processing overrides, optimization options are:\", True)\n    for k, v in _amp_state.opt_properties.options.items():\n        maybe_print(\"{:22} : {}\".format(k, v), True)\n\n    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)\n\n\ndef state_dict(destination=None):\n    if destination is None:\n        destination = OrderedDict()\n\n    for idx, loss_scaler in 
enumerate(_amp_state.loss_scalers):\n        destination['loss_scaler%d' % idx] = {\n            'loss_scale': loss_scaler.loss_scale(),\n            'unskipped': loss_scaler._unskipped,\n        }\n    return destination\n\n\ndef load_state_dict(state_dict):\n    # Check if state_dict contains the same number of loss_scalers as current setup\n    if len(state_dict) != len(_amp_state.loss_scalers):\n        print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(\n            len(state_dict), len(_amp_state.loss_scalers)))\n\n    state_dict = state_dict.copy()\n\n    nb_loss_scalers = len(_amp_state.loss_scalers)\n    unexpected_keys = []\n    # Track idx manually, since unexpected keys would also advance an enumerate counter\n    idx = 0\n    for key in state_dict:\n        if 'loss_scaler' not in key:\n            unexpected_keys.append(key)\n        else:\n            if idx > (nb_loss_scalers - 1):\n                print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(\n                    idx, nb_loss_scalers))\n                break\n            _amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']\n            _amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']\n            idx += 1\n\n    if len(unexpected_keys) > 0:\n        raise RuntimeError(\n            'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. 
'.format(\n                ', '.join('\"{}\"'.format(k) for k in unexpected_keys)))\n\n\n# TODO:  is this necessary/useful?\n# def check_option_consistency(enabled=True,\n#                              opt_level=None,\n#                              cast_model_type=None,\n#                              patch_torch_functions=None,\n#                              keep_batchnorm_fp32=None,\n#                              master_weights=None,\n#                              loss_scale=None,\n#                              enable_ddp_interop=None,\n#                              hard_override=False):\n#     \"\"\"\n#     Utility function that enables users to quickly check if the option combination they intend\n#     to use is permitted.  ``check_option_consistency`` does not require models or optimizers\n#     to be constructed, and can be called at any point in the script.  ``check_option_consistency``\n#     is totally self-contained; it does not set any amp global state or affect anything outside\n#     of itself.\n#     \"\"\"\n#\n#     if not enabled:\n#         return\n#\n#     if opt_level not in opt_levels:\n#         raise RuntimeError(\"Unexpected optimization level.  
Options are 'O0', 'O1', 'O2', 'O3'.\")\n#     else:\n#         opt_properties = opt_levels[opt_level](Properties())\n#         print(\"Selected optimization level {}\", opt_levels[opt_level].brief)\n#         print(\"Defaults for this optimization level are:\")\n#         for k, v in opt_properties.options:\n#             print(\"{:22} : {}\".format(k, v))\n#\n#     print(\"Processing user overrides (additional kwargs that are not None)...\")\n#     for k, v in kwargs:\n#         if k not in _amp_state.opt_properties.options:\n#             raise RuntimeError(\"Unexpected kwarg {}\".format(k))\n#         if v is not None:\n#             setattr(opt_properties, k, v)\n#\n#     print(\"After processing overrides, optimization options are:\")\n#     for k, v in opt_properties.options:\n#         print(\"{:22} : {}\".format(k, v))\n"
  },
  {
    "path": "KoSimCSE/apex/amp/handle.py",
    "content": "import contextlib\nimport warnings\nimport sys\nimport torch\n\nfrom . import utils\nfrom .opt import OptimWrapper\nfrom .scaler import LossScaler\nfrom ._amp_state import _amp_state, master_params, maybe_print\n\nif torch.distributed.is_available():\n    from ..parallel.LARC import LARC\n\n\n# There's no reason to expose the notion of a \"handle\". Everything can happen through amp.* calls.\n@contextlib.contextmanager\ndef scale_loss(loss,\n               optimizers,\n               loss_id=0,\n               model=None,\n               delay_unscale=False,\n               delay_overflow_check=False):\n    \"\"\"\n    On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.\n    ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::\n\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n\n    On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs\n    and unscaled, so that ``optimizer.step()`` can be called.\n\n    .. note::\n        If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and\n        can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)\n        any FP16 gradients are copied to FP32 master gradients before being unscaled.\n        ``optimizer.step()`` will then apply the unscaled master gradients to the master params.\n\n    .. warning::\n        If Amp is using explicit FP32 master params, only the FP32 master gradients will be\n        unscaled.  The direct ``.grad`` attributes of any FP16\n        model params will remain scaled after context manager exit.\n        This subtlety affects gradient clipping.  See \"Gradient clipping\" under\n        `Advanced Amp Usage`_ for best practices.\n\n    Args:\n        loss(Tensor):  Typically a scalar Tensor. 
The ``scaled_loss`` that the context\n            manager yields is simply ``loss.float()*loss_scale``, so in principle\n            ``loss`` could have more than one element, as long as you call\n            ``backward()`` on ``scaled_loss`` appropriately within the context manager body.\n        optimizers:  All optimizer(s) for which the current backward pass is creating gradients.\n            Must be an optimizer or list of optimizers returned from an earlier call\n            to ``amp.initialize``.  For example use with multiple optimizers, see\n            \"Multiple models/optimizers/losses\" under `Advanced Amp Usage`_.\n        loss_id(int, optional, default=0):  When used in conjunction with the ``num_losses`` argument\n            to ``amp.initialize``, enables Amp to use a different loss scale per loss.  ``loss_id``\n            must be an integer between 0 and ``num_losses`` that tells Amp which loss is\n            being used for the current backward pass.  See \"Multiple models/optimizers/losses\"\n            under `Advanced Amp Usage`_ for examples.  
If ``loss_id`` is left unspecified, Amp\n            will use the default global loss scaler for this backward pass.\n        model(torch.nn.Module, optional, default=None):  Currently unused, reserved to enable future\n            optimizations.\n        delay_unscale(bool, optional, default=False):  ``delay_unscale`` is never necessary, and\n            the default value of ``False`` is strongly recommended.\n            If ``True``, Amp will not unscale the gradients or perform model->master\n            gradient copies on context manager exit.\n            ``delay_unscale=True`` is a minor ninja performance optimization and can result\n            in weird gotchas (especially with multiple models/optimizers/losses),\n            so only use it if you know what you're doing.\n            \"Gradient accumulation across iterations\" under `Advanced Amp Usage`_\n            illustrates a situation where this CAN (but does not need to) be used.\n\n    .. warning::\n        If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be\n        called yet after context manager exit, and must wait for another, later backward context\n        manager invocation with ``delay_unscale`` left to False.\n\n    .. _`Advanced Amp Usage`:\n        https://nvidia.github.io/apex/advanced.html\n    \"\"\"\n    if not hasattr(_amp_state, \"opt_properties\"):\n        raise RuntimeError(\"Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized.  \"\n                           \"model, optimizer = amp.initialize(model, optimizer, opt_level=...) 
must be called \"\n                           \"before `with amp.scale_loss`.\")\n\n    if not _amp_state.opt_properties.enabled:\n        yield loss\n        return\n\n    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):\n        optimizers = [optimizers]\n\n    loss_scaler = _amp_state.loss_scalers[loss_id]\n    loss_scale = loss_scaler.loss_scale()\n\n    if ((not _amp_state.opt_properties.master_weights)\n        and (not loss_scaler.dynamic)\n        and loss_scale == 1.0):\n        yield loss.float()\n        # Needing to drop the cache here as well is an ugly gotcha.\n        # But for now I think it's necessary to short-circuit.\n        # Probably ok to skip this if not delay_unscale\n        if _amp_state.opt_properties.patch_torch_functions:\n            _amp_state.handle._clear_cache()\n        return\n\n    if not delay_unscale:\n        if isinstance(optimizers, list):\n            for optimizer in optimizers:\n                if not optimizer._amp_stash.params_have_scaled_gradients:\n                    optimizer._prepare_amp_backward()\n\n    yield (loss.float())*loss_scale\n\n    if delay_unscale:\n        for optimizer in optimizers:\n            optimizer._amp_stash.params_have_scaled_gradients = True\n    else:\n        # FusedSGD may take care of unscaling as part of their step() methods.\n        # if not isinstance(optimizers, FP16_Optimizer_for_fused):\n            loss_scaler.clear_overflow_state()\n            for optimizer in optimizers:\n                optimizer._post_amp_backward(loss_scaler)\n                optimizer._amp_stash.params_have_scaled_gradients = False\n            # For future fused optimizers that enable sync-free dynamic loss scaling,\n            # should_skip will always be False.\n            should_skip = False if delay_overflow_check else loss_scaler.update_scale()\n            if should_skip:\n                for optimizer in optimizers:\n                 
   if not optimizer._amp_stash.already_patched:\n                        # Close on loss_scaler and loss_id as well, to be safe.  Probably not\n                        # necessary because amp.scale_loss is already creating a temporary scope.\n                        def patch_step(opt, loss_scaler, loss_id):\n                            opt_step = opt.step\n                            def skip_step(closure=None):\n                                if closure is not None:\n                                    raise RuntimeError(\"Currently, Amp does not support closure use with optimizers.\")\n                                maybe_print((\"Gradient overflow.  Skipping step, loss scaler \" +\n                                             \"{} reducing loss scale to {}\").format(loss_id,\n                                             loss_scaler.loss_scale()))\n                                # TODO:  I don't like the special casing for different optimizer implementations.\n                                # Maybe skip should delegate to a method owned by the optimizers themselves.\n                                if hasattr(opt._amp_stash, \"all_fp32_from_fp16_params\"):\n                                    # Clear the master grads that wouldn't be zeroed by model.zero_grad()\n                                    for param in opt._amp_stash.all_fp32_from_fp16_params:\n                                        param.grad = None\n                                if hasattr(opt, \"most_recent_scale\"):\n                                    opt.most_recent_scale = 1.0\n                                    opt.scale_set_by_backward = False\n                                opt.step = opt_step\n                                opt._amp_stash.already_patched = False\n                            return skip_step\n                        optimizer.step = patch_step(optimizer, loss_scaler, loss_id)\n                        optimizer._amp_stash.already_patched = True\n\n    # Probably ok to 
skip this if not delay_unscale\n    if _amp_state.opt_properties.patch_torch_functions:\n        _amp_state.handle._clear_cache()\n\n\n# Free function version of AmpHandle.disable_casts, another step on the\n# path to removing the concept of \"AmpHandle\"\n@contextlib.contextmanager\ndef disable_casts():\n    _amp_state.handle._is_active = False\n    yield\n    _amp_state.handle._is_active = True\n\n\nclass AmpHandle(object):\n    def __init__(self, loss_scale=\"dynamic\", enable_caching=True, verbose=False):\n        self._enable_caching = enable_caching\n        self._verbose = verbose\n        self._cache = dict()\n        self._default_scaler = LossScaler(loss_scale)\n        self._is_active = True\n        self._all_wrappers = []\n\n    def is_active(self):\n        return self._is_active\n\n    @contextlib.contextmanager\n    def _disable_casts(self):\n        self._is_active = False\n        yield\n        self._is_active = True\n\n    def wrap_optimizer(self, optimizer, num_loss=1):\n        self._default_scaler = None\n        return OptimWrapper(optimizer, self, num_loss)\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss, optimizer):\n        raise RuntimeError(\"The old Amp API is no longer supported.  Please move to the new API, \"\n            \"documented here:  https://nvidia.github.io/apex/amp.html.  Transition guide:  \"\n            \"https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users\")\n\n        if not self.is_active():\n            yield loss\n            return\n\n        if self._default_scaler is None:\n            raise RuntimeError(\n                'After calling `handle.wrap_optimizer()`, you must explicitly ' +\n                'use `optimizer.scale_loss(loss)`.')\n\n        # TODO: this code block is duplicated here and `opt.py`. 
Unify.\n        loss_scale = self._default_scaler.loss_scale()\n        yield loss * loss_scale\n\n        self._default_scaler.clear_overflow_state()\n        self._default_scaler.unscale(\n            master_params(optimizer),\n            master_params(optimizer),\n            loss_scale)\n        should_skip = self._default_scaler.update_scale()\n        if should_skip:\n            optimizer_step = optimizer.step\n            def skip_step():\n                maybe_print('Gradient overflow, skipping update')\n                optimizer.step = optimizer_step\n            optimizer.step = skip_step\n\n        self._clear_cache()\n\n    def _clear_cache(self):\n        self._cache.clear()\n\n    # Experimental support for saving / restoring uncasted versions of functions\n    def _save_func(self, mod, fn, func):\n        self._all_wrappers.append((mod, fn, func))\n\n    def _deactivate(self):\n        for mod, fn, func in self._all_wrappers:\n            utils.set_func(mod, fn, func)\n        self._all_wrappers = []\n\n    @property\n    def has_cache(self):\n        return self._enable_caching\n\n    @property\n    def cache(self):\n        return self._cache\n\n    def remove_cache(self, param):\n        if self.has_cache and param in self.cache:\n            del self.cache[param]\n\n    @property\n    def verbose(self):\n        return self._verbose\n\nclass NoOpHandle(object):\n    def is_active(self):\n        return False\n\n    @contextlib.contextmanager\n    def _disable_casts(self):\n        yield\n\n    def wrap_optimizer(self, optimizer, num_loss=1):\n        return OptimWrapper(optimizer, self, num_loss)\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss, optimizer):\n        yield loss\n\n    @property\n    def has_cache(self):\n        return False\n\n    @property\n    def verbose(self):\n        return False\n\n    def _clear_cache(self):\n        pass\n\n    def _deactivate(self):\n        pass\n"
  },
  {
    "path": "KoSimCSE/apex/amp/lists/__init__.py",
    "content": ""
  },
  {
    "path": "KoSimCSE/apex/amp/lists/functional_overrides.py",
    "content": "\n# TODO: think about the following two. They do weird things.\n# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)\n# - torch.nn.utils.weight_norm\n\n# Notes:\n# F.instance_norm uses batch_norm internally. Which correctly handles\n#   fp16 in/out with fp32 weights. So we shouldn't do anything for\n#   either of these.\n# F.normalize calls `input.norm()` internally, so it's redundant, but\n#   kept here in case impl. changes.\n# F.cosine_similarity is same: calls `x.norm()` internally.\n\nimport torch.nn.functional\n\nMODULE = torch.nn.functional\n\nFP16_FUNCS = [\n    'conv1d',\n    'conv2d',\n    'conv3d',\n    'conv_transpose1d',\n    'conv_transpose2d',\n    'conv_transpose3d',\n    'conv_tbc', # Undocumented / maybe new?\n    'linear',\n]\n\nFP32_FUNCS = [\n\n    # Interpolation/Upsampling TODO:  Remove for 1.2\n    'interpolate',\n    'grid_sample',\n\n    # Pointwise\n    'softplus',\n    'softmin',\n    'log_softmax',\n    'softmax',\n    'gelu',\n    \n    # Normalization\n    'layer_norm',\n    'group_norm',\n    'local_response_norm',\n    'normalize',\n    'cosine_similarity',\n\n    # Loss functions\n    # TODO: which of these can be fp16?\n    'poisson_nll_loss',\n    'cosine_embedding_loss',\n    'cross_entropy',\n    'hinge_embedding_loss',\n    'kl_div',\n    'l1_loss',\n    'mse_loss',\n    'margin_ranking_loss',\n    'multilabel_margin_loss',\n    'multilabel_soft_margin_loss',\n    'multi_margin_loss',\n    'nll_loss',\n    'binary_cross_entropy_with_logits',\n    'smooth_l1_loss',\n    'soft_margin_loss',\n    'triplet_margin_loss',\n    'ctc_loss'\n]\n\nBANNED_FUNCS = [\n    ('binary_cross_entropy',\n     (\"\\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` \"\n      \"It requires that the output of the previous function be already a FloatTensor. \\n\\n\"\n      \"Most models have a Sigmoid right before BCELoss. 
In that case, you can use\\n\"\n      \"    torch.nn.BCEWithLogitsLoss\\nto combine Sigmoid+BCELoss into a single layer \"\n      \"that is compatible with amp.\\nAnother option is to add\\n\"\n      \"    amp.register_float_function(torch, 'sigmoid')\\nbefore calling `amp.init()`.\\n\"\n      \"If you _really_ know what you are doing, you can disable this warning by passing \"\n      \"allow_banned=True to `amp.init()`.\"))\n]\n"
  },
  {
    "path": "KoSimCSE/apex/amp/lists/tensor_overrides.py",
    "content": "from .. import compat\nfrom . import torch_overrides\n\nimport importlib\n\nimport torch\n\n# if compat.variable_is_tensor() and not compat.tensor_is_variable():\nMODULE = torch.Tensor\n# else:\n#     MODULE = torch.autograd.Variable\n\n\nFP16_FUNCS = compat.filter_attrs(MODULE, [\n    '__matmul__',\n])\n\nFP32_FUNCS = compat.filter_attrs(MODULE, [\n    '__ipow__',\n    '__pow__',\n    '__rpow__',\n\n    # Cast to fp32 before transfer to CPU\n    'cpu',\n])\n\nCASTS = compat.filter_attrs(MODULE, [\n    '__add__',\n    '__div__',\n    '__eq__',\n    '__ge__',\n    '__gt__',\n    '__iadd__',\n    '__idiv__',\n    '__imul__',\n    '__isub__',\n    '__itruediv__',\n    '__le__',\n    '__lt__',\n    '__mul__',\n    '__ne__',\n    '__radd__',\n    '__rdiv__',\n    '__rmul__',\n    '__rsub__',\n    '__rtruediv__',\n    '__sub__',\n    '__truediv__',\n])\n\n# None of these, but here to make code cleaner.\nSEQUENCE_CASTS = []\n\n# We need to grab all the methods from torch_overrides and add them to\n# the Tensor lists as well, as almost all methods are duplicated\n# between `torch` and `torch.Tensor` (and check with `hasattr`,\n# because a few random ones aren't defined on Tensor)\n_self_mod = importlib.import_module(__name__)\nfor attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:\n    lst = getattr(_self_mod, attrname)\n    for fn in getattr(torch_overrides, attrname):\n        if hasattr(MODULE, fn):\n            lst.append(fn)\n"
  },
  {
    "path": "KoSimCSE/apex/amp/lists/torch_overrides.py",
    "content": "import torch\n\nfrom .. import utils\n\nMODULE = torch\n\nFP16_FUNCS = [\n    # Low level functions wrapped by torch.nn layers.\n    # The wrapper layers contain the weights which are then passed in as a parameter\n    # to these functions.\n    'conv1d',\n    'conv2d',\n    'conv3d',\n    'conv_transpose1d',\n    'conv_transpose2d',\n    'conv_transpose3d',\n    'conv_tbc',\n    'prelu',\n\n    # BLAS\n    'addmm',\n    'addmv',\n    'addr',\n    'matmul',\n    'mm',\n    'mv',\n]\n\nFP32_FUNCS = [\n    # Pointwise\n    'acos',\n    'asin',\n    'cosh',\n    'erfinv',\n    'exp',\n    'expm1',\n    'log',\n    'log10',\n    'log2',\n    'reciprocal',\n    'rsqrt',\n    'sinh',\n    'tan',\n\n    # Other math\n    'pow',\n\n    # Reduction\n    'cumprod',\n    'cumsum',\n    'dist',\n    # 'mean',\n    'norm',\n    'prod',\n    'std',\n    'sum',\n    'var',\n\n    # Misc\n    'renorm'\n]\n\nversion_strings = torch.__version__.split('.')\nversion_major = version_strings[0]\nversion_minor = version_strings[1]\nversion_num = float(version_major + \".\" + version_minor)\n# Before torch 1.1, mean must be blacklisted.\nif version_num < 1.1:\n    FP32_FUNCS.append('mean')\n\n# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We\n# check the CUDA version -- if at least 9.1, then put the bmm\n# functions on the fp16 list. 
Otherwise, put them on the fp32 list.\n_bmms = ['addbmm',\n         'baddbmm',\n         'bmm']\n\nif utils.is_cuda_enabled():\n    # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802\n    if utils.get_cuda_version() >= (9, 1, 0):\n        FP16_FUNCS.extend(_bmms)\n    else:\n        FP32_FUNCS.extend(_bmms)\n\n# Multi-tensor fns that may need type promotion\nCASTS = [\n    # Multi-tensor math\n    'addcdiv',\n    'addcmul',\n    'atan2',\n    'cross',\n    'bilinear',\n    'dot',\n\n    # Element-wise _or_ tensor-wise math\n    'add',\n    'div',\n    'mul',\n\n    # Comparison\n    'eq',\n    'equal',\n    'ge',\n    'gt',\n    'le',\n    'lt',\n    'ne'\n]\n\n# Functions that take sequence arguments. We need to inspect the whole\n# sequence and cast to the widest type.\nSEQUENCE_CASTS = [\n    'cat',\n    'stack'\n]\n"
  },
  {
    "path": "KoSimCSE/apex/amp/opt.py",
    "content": "import contextlib\nimport warnings\n\nfrom .scaler import LossScaler, master_params\nfrom ._amp_state import maybe_print\n\nimport numpy as np\n\nclass OptimWrapper(object):\n    def __init__(self, optimizer, amp_handle, num_loss):\n        self._optimizer = optimizer\n        self._amp_handle = amp_handle\n        self._num_loss = num_loss\n        self._loss_idx = 0\n        self._skip_next = [False] * num_loss\n        self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]\n\n    @contextlib.contextmanager\n    def scale_loss(self, loss):\n        if not self._amp_handle.is_active():\n            yield loss\n            return\n\n        # When there are multiple losses per-optimizer, we need\n        # to save off the current grad accumulation, since we won't be\n        # able to unscale this particular loss once the grads are\n        # all mixed together.\n        cached_grads = []\n        if self._loss_idx > 0:\n            for p in master_params(self._optimizer):\n                if p.grad is not None:\n                    cached_grads.append(p.grad.data.detach().clone())\n                else:\n                    cached_grads.append(None)\n            self._optimizer.zero_grad()\n\n        loss_scale = self._cur_loss_scaler().loss_scale()\n        yield loss * loss_scale\n\n        self._cur_loss_scaler().clear_overflow_state()\n        self._cur_loss_scaler().unscale(\n            master_params(self._optimizer),\n            master_params(self._optimizer),\n            loss_scale)\n        self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()\n        self._loss_idx += 1\n\n        if len(cached_grads) > 0:\n            for p, cached_grad in zip(master_params(self._optimizer),\n                                      cached_grads):\n                if cached_grad is not None:\n                    p.grad.data.add_(cached_grad)\n            cached_grads = []\n\n    def _cur_loss_scaler(self):\n        assert 0 <= self._loss_idx < self._num_loss\n        return self._loss_scaler[self._loss_idx]\n\n    def step(self, closure=None):\n        if not self._amp_handle.is_active():\n            return self._optimizer.step(closure=closure)\n\n        self._loss_idx = 0\n\n        for group in self._optimizer.param_groups:\n            for p in group['params']:\n                self._amp_handle.remove_cache(p)\n\n        if closure is not None:\n            raise NotImplementedError(\n                'The `closure` argument is unsupported by the amp ' +\n                'optimizer wrapper.')\n        if any(self._skip_next):\n            maybe_print('Gradient overflow, skipping update')\n            self._skip_next = [False] * self._num_loss\n        else:\n            return self._optimizer.step(closure=closure)\n\n    # Forward any attribute lookups\n    def __getattr__(self, attr):\n        return getattr(self._optimizer, attr)\n\n    # Forward all torch.optim.Optimizer methods\n    def __getstate__(self):\n        return self._optimizer.__getstate__()\n\n    def __setstate__(self, state):\n        return self._optimizer.__setstate__(state)\n\n    def __repr__(self):\n        return self._optimizer.__repr__()\n\n    def state_dict(self):\n        return self._optimizer.state_dict()\n\n    def load_state_dict(self, state_dict):\n        return self._optimizer.load_state_dict(state_dict)\n\n    def zero_grad(self):\n        return self._optimizer.zero_grad()\n\n    def add_param_group(self, param_group):\n        return self._optimizer.add_param_group(param_group)\n"
  },
  {
    "path": "KoSimCSE/apex/amp/rnn_compat.py",
    "content": "from . import utils, wrap\n\nimport torch\n_VF = torch._C._VariableFunctions\nRNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']\n\ndef _gen_VF_wrapper(name):\n    def wrapper(*args, **kwargs):\n        return getattr(_VF, name)(*args, **kwargs)\n    return wrapper\n\n# Some python magic to generate an object that has the rnn cell functions\n# defined on it, all of which call into corresponding _VF version.\n# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named \"_VF\"\n# imported at module scope within torch.nn.modules.rnn).  This should\n# not affect third-party importers of _VF.py.\nclass VariableFunctionsShim(object):\n    def __init__(self):\n        for name in RNN_NAMES:\n            for suffix in ['', '_cell']:\n               fn_name = name + suffix\n               setattr(self, fn_name, _gen_VF_wrapper(fn_name))\n\ndef has_old_rnns():\n    try:\n        torch.nn.backends.thnn.backend.LSTMCell\n        return True\n    except:\n        return False\n\ndef whitelist_rnn_cells(handle, verbose):\n    # Different module + function names in old/new RNN cases\n    if has_old_rnns():\n        fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']\n        mod = torch.nn.backends.thnn.backend\n    else:\n        fn_names = [x + '_cell' for x in RNN_NAMES]\n        mod = torch.nn.modules.rnn._VF\n        assert isinstance(mod, VariableFunctionsShim)\n\n    # Insert casts on cell functions\n    for fn in fn_names:\n        wrap.cached_cast(mod, fn, utils.maybe_half, handle,\n                         try_caching=True, verbose=verbose)\n\n    if has_old_rnns():\n        # Special handling of `backward` for fused gru / lstm:\n        # The `backward` method calls Tensor.sum() (blacklist) internally,\n        # and then the resulting grad_input has the wrong type.\n        # TODO: where else is this a problem?\n        for rnn_type in ['GRUFused', 'LSTMFused']:\n            mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, 
rnn_type)\n            wrap.disable_casts(mod, 'backward', handle)\n"
  },
  {
    "path": "KoSimCSE/apex/amp/scaler.py",
    "content": "import torch\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom ._amp_state import _amp_state, master_params, maybe_print\nfrom itertools import product\n\ndef scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):\n    # Exception handling for 18.04 compatibility\n    if check_overflow:\n        cpu_sum = float(model_grad.float().sum())\n        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n            return True\n\n    if master_grad is not model_grad: # copy_ probably internally short-circuits this\n        master_grad.copy_(model_grad)\n    if scale != 1.0:\n        master_grad.mul_(scale)\n    return False\n\ndef axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):\n    # Exception handling for 18.04 compatibility\n    if check_overflow:\n        cpu_sum = float(model_grad.float().sum())\n        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n            return True\n\n    # if master_grad is not model_grad: # copy_ probably internally short-circuits this\n    #     master_grad.copy_(model_grad)\n    assert stashed_grad.dtype == master_grad.dtype\n    converted_model_grad = model_grad.data.to(master_grad.dtype)\n    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data\n    return False\n\nclass LossScaler(object):\n    warned_no_fused_kernel = False\n    warned_unscaling_non_fp32_grad = False\n    has_fused_kernel = False\n\n    def __init__(self,\n                 loss_scale,\n                 init_scale=2.**16,\n                 scale_factor=2.,\n                 scale_window=2000,\n                 min_loss_scale=None,\n                 max_loss_scale=2.**24):\n        if loss_scale == \"dynamic\":\n            self.dynamic = True\n            self._loss_scale = min(max_loss_scale, init_scale)\n        else:\n            self.dynamic = False\n            self._loss_scale = 
loss_scale\n        self._max_loss_scale = max_loss_scale\n        self._min_loss_scale = min_loss_scale\n        self._scale_seq_len = scale_window\n        self._unskipped = 0\n        self._has_overflow = False\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        if multi_tensor_applier.available:\n            import amp_C\n            LossScaler.has_fused_kernel = multi_tensor_applier.available\n            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale\n            LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby\n        else:\n            if not LossScaler.warned_no_fused_kernel:\n                maybe_print(\n                    \"Warning:  multi_tensor_applier fused unscale kernel is unavailable, \"\n                    \"possibly because apex was installed without --cuda_ext --cpp_ext. \"\n                    \"Using Python fallback.  Original ImportError was: \" +\n                    repr(multi_tensor_applier.import_err),\n                    True)\n            LossScaler.has_fused_kernel = False\n            LossScaler.warned_no_fused_kernel = True\n\n    def loss_scale(self):\n        return self._loss_scale\n\n    def unscale_python(self, model_grads, master_grads, scale):\n        for model, master in zip(model_grads, master_grads):\n            if model is not None:\n                if not LossScaler.warned_unscaling_non_fp32_grad:\n                    if master.dtype != torch.float32:\n                        maybe_print(\n                            \"Attempting to unscale a grad with type {} \".format(master.type()) +\n                            \"Unscaling non-fp32 grads may indicate an error. 
\"\n                            \"When using Amp, you don't need to call .half() on your model.\")\n                        LossScaler.warned_unscaling_non_fp32_grad = True\n                self._has_overflow = scale_check_overflow_python(model,\n                                                                 master,\n                                                                 1./scale,\n                                                                 self.dynamic)\n                if self._has_overflow and self.dynamic:\n                    break\n\n    # unused_scale keeps some of the old API alive for hopefully a short time.\n    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):\n        if self._has_overflow:\n            return\n\n        scale = self._loss_scale\n        if scale_override is not None:\n            scale = scale_override\n\n        if scale == 1.0 and models_are_masters and not self.dynamic:\n            return\n\n        if LossScaler.has_fused_kernel:\n            # if (not LossScaler.warned_unscaling_non_fp32_grad\n            #     and master_grads[0].dtype == torch.float16):\n            #     print(\"Warning:  unscaling grads that are not FP32. \"\n            #           \"Unscaling non-fp32 grads may indicate an error. 
\"\n            #           \"When using Amp, you don't need to call .half() on your model.\")\n            #     # Setting this to True unconditionally allows the possibility of an escape\n            #     # if never-before-seen non-fp32 grads are created in some later iteration.\n            #     LossScaler.warned_unscaling_non_fp32_grad = True\n            multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,\n                                 self._overflow_buf,\n                                 [model_grads, master_grads],\n                                 1./scale)\n        else:\n            self.unscale_python(model_grads, master_grads, scale)\n\n        # Defer to update_scale\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n        #     self._has_overflow = self._overflow_buf.item()\n\n    def unscale_with_stashed_python(self,\n                                    model_grads,\n                                    stashed_master_grads,\n                                    master_grads,\n                                    a,\n                                    b):\n        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):\n            if model is None and stashed is None:\n                continue\n            else:\n                if not LossScaler.warned_unscaling_non_fp32_grad:\n                    if master.dtype != torch.float32:\n                        maybe_print(\n                            \"Attempting to unscale a grad with type {} \".format(master.type()) +\n                            \"Unscaling non-fp32 grads may indicate an error. 
\"\n                            \"When using Amp, you don't need to call .half() on your model.\")\n                        LossScaler.warned_unscaling_non_fp32_grad = True\n                self._has_overflow = axpby_check_overflow_python(model,\n                                                                 stashed,\n                                                                 master,\n                                                                 a,\n                                                                 b,\n                                                                 self.dynamic)\n                if self._has_overflow and self.dynamic:\n                    break\n\n    def unscale_with_stashed(self,\n                             model_grads,\n                             stashed_master_grads,\n                             master_grads,\n                             scale_override=None):\n        if self._has_overflow:\n            return\n\n        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0\n        if scale_override is not None:\n            grads_have_scale, stashed_have_scale, out_scale = scale_override\n\n        if LossScaler.has_fused_kernel:\n            if (not LossScaler.warned_unscaling_non_fp32_grad\n                and master_grads[0].dtype == torch.float16):\n                print(\"Warning:  unscaling grads that are not FP32. \"\n                      \"Unscaling non-fp32 grads may indicate an error. 
\"\n                      \"When using Amp, you don't need to call .half() on your model.\")\n                # Setting this to True unconditionally allows the possibility of an escape\n                # if never-before-seen non-fp32 grads are created in some later iteration.\n                LossScaler.warned_unscaling_non_fp32_grad = True\n            multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,\n                                 self._overflow_buf,\n                                 [model_grads, stashed_master_grads, master_grads],\n                                 out_scale/grads_have_scale,   # 1./scale,\n                                 out_scale/stashed_have_scale, # 1.0,\n                                 0) # check only arg 0, aka the incoming model grads, for infs\n        else:\n            self.unscale_with_stashed_python(model_grads,\n                                             stashed_master_grads,\n                                             master_grads,\n                                             out_scale/grads_have_scale,\n                                             out_scale/stashed_have_scale)\n\n        # Defer to update_scale\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n        #     self._has_overflow = self._overflow_buf.item()\n\n    def clear_overflow_state(self):\n        self._has_overflow = False\n        if self.has_fused_kernel:\n            self._overflow_buf.zero_()\n\n    # Separate so unscale() can be called more than once before updating.\n    def update_scale(self):\n        # If the fused kernel is available, we only need one D2H memcopy and sync.\n        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:\n            self._has_overflow = self._overflow_buf.item()\n\n        if self._has_overflow and self.dynamic:\n            should_skip = True\n            if self._min_loss_scale:\n                self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)\n            else:\n                self._loss_scale = self._loss_scale/2.\n            self._unskipped = 0\n        else:\n            should_skip = False\n            self._unskipped += 1\n\n        if self._unskipped == self._scale_seq_len and self.dynamic:\n            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)\n            self._unskipped = 0\n\n        return should_skip\n"
  },
  {
    "path": "KoSimCSE/apex/amp/utils.py",
    "content": "from . import compat\n\nimport functools\nimport itertools\n\nimport torch\n\ndef is_cuda_enabled():\n    return torch.version.cuda is not None\n\ndef get_cuda_version():\n    return tuple(int(x) for x in torch.version.cuda.split('.'))\n\ndef is_fp_tensor(x):\n    if is_nested(x):\n        # Fast-fail version of all(is_fp_tensor)\n        for y in x:\n            if not is_fp_tensor(y):\n                return False\n        return True\n    return compat.is_tensor_like(x) and compat.is_floating_point(x)\n\ndef is_nested(x):\n    return isinstance(x, tuple) or isinstance(x, list)\n\ndef should_cache(x):\n    if is_nested(x):\n        # Fast-fail version of all(should_cache)\n        for y in x:\n            if not should_cache(y):\n                return False\n        return True\n    return isinstance(x, torch.nn.parameter.Parameter) and \\\n        type_string(x) == 'FloatTensor'\n\ndef collect_fp_tensor_types(args, kwargs):\n    def collect_types(x, types):\n        if is_nested(x):\n            for y in x:\n                collect_types(y, types)\n        else:\n            types.add(type_string(x))\n\n    all_args = itertools.chain(args, kwargs.values())\n    types = set()\n    for x in all_args:\n        if is_fp_tensor(x):\n            collect_types(x, types)\n    return types\n\ndef type_string(x):\n    return x.type().split('.')[-1]\n\ndef maybe_half(x, name='', verbose=False):\n    if is_nested(x):\n        return type(x)([maybe_half(y) for y in x])\n\n    if not x.is_cuda or type_string(x) == 'HalfTensor':\n        return x\n    else:\n        if verbose:\n            print('Float->Half ({})'.format(name))\n        return x.half()\n\ndef maybe_float(x, name='', verbose=False):\n    if is_nested(x):\n        return type(x)([maybe_float(y) for y in x])\n\n    if not x.is_cuda or type_string(x) == 'FloatTensor':\n        return x\n    else:\n        if verbose:\n            print('Half->Float ({})'.format(name))\n        return 
x.float()\n\n# NB: returns casted `args`, mutates `kwargs` in-place\ndef casted_args(cast_fn, args, kwargs):\n    new_args = []\n    for x in args:\n        if is_fp_tensor(x):\n            new_args.append(cast_fn(x))\n        else:\n            new_args.append(x)\n    for k in kwargs:\n        val = kwargs[k]\n        if is_fp_tensor(val):\n            kwargs[k] = cast_fn(val)\n    return new_args\n\ndef cached_cast(cast_fn, x, cache):\n    if is_nested(x):\n        return type(x)([cached_cast(cast_fn, y, cache) for y in x])\n    if x in cache:\n        cached_x = cache[x]\n        if x.requires_grad and cached_x.requires_grad:\n            # Make sure x is actually cached_x's autograd parent.\n            if cached_x.grad_fn.next_functions[1][0].variable is not x:\n                raise RuntimeError(\"x and cache[x] both require grad, but x is not \"\n                                   \"cache[x]'s parent.  This is likely an error.\")\n        # During eval, it's possible to end up caching casted weights with\n        # requires_grad=False.  On the next training iter, if cached_x is found\n        # and reused from the cache, it will not actually have x as its parent.\n        # Therefore, we choose to invalidate the cache (and force refreshing the cast)\n        # if x.requires_grad and cached_x.requires_grad do not match.\n        #\n        # During eval (i.e. running under with torch.no_grad()) the invalidation\n        # check would cause the cached value to be dropped every time, because\n        # cached_x would always be created with requires_grad=False, while x would\n        # still have requires_grad=True.  This would render the cache effectively\n        # useless during eval.  Therefore, if we are running under the no_grad()\n        # context manager (torch.is_grad_enabled=False) we elide the invalidation\n        # check, and use the cached value even though its requires_grad flag doesn't\n        # match.  
During eval, we don't care that there's no autograd-graph\n        # connection between x and cached_x.\n        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:\n            del cache[x]\n        else:\n            return cached_x\n\n    casted_x = cast_fn(x)\n    cache[x] = casted_x\n    return casted_x\n\ndef verbosify(cast_fn, fn_name, verbose):\n    if verbose:\n        return functools.partial(cast_fn, name=fn_name, verbose=verbose)\n    else:\n        return cast_fn\n\ndef as_inplace(fns):\n    for x in fns:\n        yield x + '_'\n\ndef has_func(mod, fn):\n    if isinstance(mod, dict):\n        return fn in mod\n    else:\n        return hasattr(mod, fn)\n\ndef get_func(mod, fn):\n    if isinstance(mod, dict):\n        return mod[fn]\n    else:\n        return getattr(mod, fn)\n\ndef set_func(mod, fn, new_fn):\n    if isinstance(mod, dict):\n        mod[fn] = new_fn\n    else:\n        setattr(mod, fn, new_fn)\n\ndef set_func_save(handle, mod, fn, new_fn):\n    cur_fn = get_func(mod, fn)\n    handle._save_func(mod, fn, cur_fn)\n    set_func(mod, fn, new_fn)\n\n# A couple problems get solved here:\n# - The flat_weight buffer is disconnected from autograd graph,\n#   so the fp16 weights need to be derived from the input weights\n#   to this forward call, not the flat buffer.\n# - The ordering of weights in the flat buffer is...idiosyncratic.\n# First problem is solved with combination of set_ (to set up\n# correct storage) and copy_ (so the fp16 weight derives from the\n# fp32 one in autograd.\n# Second is solved by doing ptr arithmetic on the fp32 weights\n# to derive the correct offset.\n#\n# TODO: maybe this should actually use\n# `torch._cudnn_rnn_flatten_weight`? But then I need to call\n# on first iter and cache the right offsets. 
Ugh.\ndef synthesize_flattened_rnn_weights(fp32_weights,\n                                     fp16_flat_tensor,\n                                     rnn_fn='',\n                                     verbose=False):\n    fp16_weights = []\n    fp32_base_ptr = fp32_weights[0][0].data_ptr()\n    for layer_weights in fp32_weights:\n        fp16_layer_weights = []\n        for w_fp32 in layer_weights:\n            w_fp16 = w_fp32.new().half()\n            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()\n            w_fp16.set_(fp16_flat_tensor.storage(),\n                        offset,\n                        w_fp32.shape)\n            w_fp16.copy_(w_fp32)\n            if verbose:\n                print('Float->Half ({})'.format(rnn_fn))\n            fp16_layer_weights.append(w_fp16)\n        fp16_weights.append(fp16_layer_weights)\n    return fp16_weights\n\n# Roughly same as above, just the `fp32_weights` aren't nested.\n# Code kept separate for readability.\ndef new_synthesize_flattened_rnn_weights(fp32_weights,\n                                         fp16_flat_tensor,\n                                         rnn_fn='',\n                                         verbose=False):\n    fp16_weights = []\n    fp32_base_ptr = fp32_weights[0].data_ptr()\n    for w_fp32 in fp32_weights:\n        w_fp16 = w_fp32.new().half()\n        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()\n        w_fp16.set_(fp16_flat_tensor.storage(),\n                    offset,\n                    w_fp32.shape)\n        w_fp16.copy_(w_fp32)\n        if verbose:\n            print('Float->Half ({})'.format(rnn_fn))\n        fp16_weights.append(w_fp16)\n    return fp16_weights\n"
  },
  {
    "path": "KoSimCSE/apex/amp/wrap.py",
    "content": "from . import compat\nfrom . import utils\nfrom ._amp_state import _amp_state\nfrom . import rnn_compat\n\nimport functools\n\nimport torch\n\ndef make_cast_wrapper(orig_fn, cast_fn, handle,\n                      try_caching=False):\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        if not handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        if try_caching and handle.has_cache:\n            args = list(args)\n            for i in range(len(args)):\n                if utils.should_cache(args[i]):\n                    args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)\n            for k in kwargs:\n                if utils.should_cache(kwargs[k]):\n                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)\n        new_args = utils.casted_args(cast_fn,\n                                     args,\n                                     kwargs)\n        return orig_fn(*new_args, **kwargs)\n    return wrapper\n\ndef cached_cast(mod, fn, cast_fn, handle,\n                try_caching=False, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    cast_fn = utils.verbosify(cast_fn, fn, verbose)\n    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\n# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`\n# Annoyingly, make_promote_wrapper still uses the global handle.  
Once everyone\n# is on the new API and I am free to get rid of handle, I can clean this up.\ndef make_promote_wrapper(orig_fn, cast_fn, handle=None):\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        if not _amp_state.handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        types = utils.collect_fp_tensor_types(args, kwargs)\n\n        if len(types) <= 1:\n            return orig_fn(*args, **kwargs)\n        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):\n            new_args = utils.casted_args(cast_fn,\n                                         args,\n                                         kwargs)\n            return orig_fn(*new_args, **kwargs)\n        else:\n            raise NotImplementedError('Do not know how to handle ' +\n                                      'these types to promote: {}'\n                                      .format(types))\n    return wrapper\n\ndef promote(mod, fn, handle, verbose=False):\n    orig_fn = utils.get_func(mod, fn)\n    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)\n    wrapper = make_promote_wrapper(orig_fn, maybe_float)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef sequence_promote(mod, fn, handle, verbose=False):\n    orig_fn = utils.get_func(mod, fn)\n    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)\n    @functools.wraps(orig_fn)\n    def wrapper(seq, *args, **kwargs):\n        if not _amp_state.handle.is_active():\n            return orig_fn(seq, *args, **kwargs)\n\n        types = set([utils.type_string(x) for x in seq])\n        if len(types) <= 1:\n            return orig_fn(seq, *args, **kwargs)\n        elif types == set(['HalfTensor', 'FloatTensor']):\n            cast_seq = utils.casted_args(maybe_float,\n                                         seq, {})\n            return orig_fn(cast_seq, *args, **kwargs)\n        else:\n            # TODO: other mixed-type cases aren't due to amp.\n            #   
    Just pass through?\n            return orig_fn(seq, *args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef promote_match_arg0(mod, fn, handle, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(arg0, *args, **kwargs):\n        assert compat.is_tensor_like(arg0)\n        if not _amp_state.handle.is_active():\n            return orig_fn(arg0, *args, **kwargs)\n\n        if utils.type_string(arg0) == 'HalfTensor':\n            cast_fn = utils.maybe_half\n        elif utils.type_string(arg0) == 'FloatTensor':\n            cast_fn = utils.maybe_float\n        else:\n            return orig_fn(arg0, *args, **kwargs)\n        cast_fn = utils.verbosify(cast_fn, fn, verbose)\n        new_args = utils.casted_args(cast_fn, args, kwargs)\n        return orig_fn(arg0, *new_args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef err_if_any_half(mod, fn, handle, custom_err_msg=None):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        types = utils.collect_fp_tensor_types(args, kwargs)\n        if 'HalfTensor' in types:\n            if custom_err_msg:\n                raise NotImplementedError(custom_err_msg)\n            else:\n                raise NotImplementedError('Cannot call in-place function ' +\n                                          '{} with fp16 arguments.'.format(fn))\n        else:\n            return orig_fn(*args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\ndef err_if_arg0_half(mod, fn, handle, verbose=False):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(arg0, *args, **kwargs):\n        assert compat.is_tensor_like(arg0)\n        if utils.type_string(arg0) == 'HalfTensor':\n            
raise NotImplementedError('Cannot call in-place method ' +\n                                      '{} on fp16 Tensors.'.format(fn))\n        else:\n            cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)\n            new_args = utils.casted_args(cast_fn, args, kwargs)\n            return orig_fn(arg0, *new_args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n\n# Current RNN approach:\n# - Wrap top-level `RNN` function in thnn backend\n# - Will call into either CudnnRNN or AutogradRNN\n#  - Each of these are factory functions that return a per-iter\n#    `forward` function\n# - We interpose on the factory function to:\n#   1) Interpose on the actual forward function and put in casts\n#   2) Insert an fp16 `flat_weight` if necessary\ndef rnn_cast(backend, fn, handle, verbose=False):\n    orig_rnn = utils.get_func(backend, fn)\n    @functools.wraps(orig_rnn)\n    def rnn_wrapper(*args, **kwargs):\n        flat_weight = kwargs.get('flat_weight')\n        if flat_weight is not None:\n            # We replace `flat_weight` with an uninitialized fp16\n            # Tensor. The \"actual\" weight tensors (provided in `forward`),\n            # will then be set up as ptrs into the buffer and have the\n            # corresponding fp32 values copied in.\n            # We need to call `copy` on the \"actual\" weights so that the\n            # autograd graph correctly backprops from the wgrads computed\n            # inside cuDNN (on fp16 weights) into the fp32 weights.\n            assert utils.type_string(flat_weight) == 'FloatTensor'\n            if compat.tensor_is_float_tensor() or compat.tensor_is_variable():\n                # Pre-0.4. 
A little slower, since it zeros out memory.\n                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)\n            else:\n                flat_weight_fp16 = torch.empty_like(flat_weight,\n                                                    dtype=torch.float16)\n            kwargs['flat_weight'] = flat_weight_fp16\n        else:\n            flat_weight_fp16 = None\n\n        forward = orig_rnn(*args, **kwargs)\n        @functools.wraps(forward)\n        def fwd_wrapper(*fargs, **fkwargs):\n            assert len(fargs) == 3 or len(fargs) == 4\n            inputs, weights, hiddens = fargs[:3]\n            assert utils.is_fp_tensor(inputs)\n            assert isinstance(weights, list)\n            cast_fn = utils.verbosify(utils.maybe_half,\n                                      fn,\n                                      verbose)\n            new_args = []\n\n            # 0) Inputs\n            new_args.append(cast_fn(inputs))\n\n            # 1) Weights\n            if flat_weight_fp16 is not None:\n                fp16_weights = utils.synthesize_flattened_rnn_weights(\n                    weights, flat_weight_fp16, fn, verbose)\n            else:\n                fp16_weights = [[cast_fn(w) for w in layer]\n                                for layer in weights]\n            new_args.append(fp16_weights)\n\n            # 2) Hiddens: either a tuple (for LSTM) or a single tensor\n            if isinstance(hiddens, tuple):\n                new_args.append(tuple(cast_fn(x) for x in hiddens))\n            elif utils.is_fp_tensor(hiddens):\n                new_args.append(cast_fn(hiddens))\n            else:\n                # Hiddens can, in principle, be `None` -- pass through\n                new_args.append(hiddens)\n\n            # 3) Batch sizes (0.4 or later only)\n            if len(fargs) == 4:\n                new_args.append(fargs[3])\n\n            return forward(*new_args, **fkwargs)\n        return fwd_wrapper\n    
utils.set_func_save(handle, backend, fn, rnn_wrapper)\n\ndef new_rnn_cast(fn, handle, verbose=False):\n    # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744\n    # For rnn backend calls that route through _rnn_impls, we must patch the ref\n    # that _rnn_impls stashed.  For rnn backend calls that directly invoke\n    # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,\n    # which in turn has patched the ref named \"_VF\" in torch.nn.modules.rnn.\n    if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):\n        mod = torch.nn.modules.rnn._rnn_impls\n    else:\n        mod = torch.nn.modules.rnn._VF\n        assert isinstance(mod, rnn_compat.VariableFunctionsShim)\n        fn = fn.lower()\n    orig_fn = utils.get_func(mod, fn)\n    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        # Exact call signature from modules/rnn.py\n        assert len(args) == 9\n        assert len(kwargs) == 0\n\n        if not _amp_state.handle.is_active():\n            return orig_fn(*args, **kwargs)\n\n        if isinstance(args[6], bool):\n            params_idx = 2 # Not PackedSequence case\n        else:\n            params_idx = 3 # PackedSequence case\n\n        new_args = []\n        for i, arg in enumerate(args):\n            if i == params_idx:\n                num_params = sum([x.numel() for x in arg])\n                fp16_weight_buf = args[0].new_empty((num_params,),\n                                                    dtype=torch.half)\n                casted_weights = utils.new_synthesize_flattened_rnn_weights(\n                    arg, fp16_weight_buf, fn, verbose)\n                new_args.append(casted_weights)\n            elif utils.is_fp_tensor(arg):\n                new_args.append(cast_fn(arg))\n            else:\n                new_args.append(arg)\n\n        return orig_fn(*new_args)\n    utils.set_func_save(handle, mod, 
fn, wrapper)\n\ndef disable_casts(mod, fn, handle):\n    if not utils.has_func(mod, fn):\n        return\n\n    orig_fn = utils.get_func(mod, fn)\n    @functools.wraps(orig_fn)\n    def wrapper(*args, **kwargs):\n        with handle._disable_casts():\n            return orig_fn(*args, **kwargs)\n    utils.set_func_save(handle, mod, fn, wrapper)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/__init__.py",
    "content": ""
  },
  {
    "path": "KoSimCSE/apex/contrib/bottleneck/__init__.py",
    "content": "from .bottleneck import Bottleneck\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/bottleneck/bottleneck.py",
    "content": "import torch\nfrom torch import nn\nimport fast_bottleneck\n\ndef kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):\n    weight_tensor_nchw = tensor\n    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)\n\nclass FrozenBatchNorm2d(torch.nn.Module):\n    \"\"\"\n    BatchNorm2d where the batch statistics and the affine parameters are fixed\n    \"\"\"\n    def __init__(self, n):\n        super(FrozenBatchNorm2d, self).__init__()\n        self.register_buffer(\"weight\", torch.ones(n))\n        self.register_buffer(\"bias\", torch.zeros(n))\n        self.register_buffer(\"running_mean\", torch.zeros(n))\n        self.register_buffer(\"running_var\", torch.ones(n))\n\n    def get_scale_bias(self, nhwc=False):\n        scale = self.weight * self.running_var.rsqrt()\n        bias = self.bias - self.running_mean * scale\n        if nhwc:\n            scale = scale.reshape(1, 1, 1, -1)\n            bias = bias.reshape(1, 1, 1, -1)\n        else:\n            scale = scale.reshape(1, -1, 1, 1)\n            bias = bias.reshape(1, -1, 1, 1)\n        return scale, bias\n\n    def forward(self, x):\n        scale, bias = self.get_scale_bias()\n        return x * scale + bias\n\n\n@torch.jit.script\ndef drelu_dscale1(grad_o, output, scale1):\n    relu_mask = (output>0).half()\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    return g1, dx_relu\n\n@torch.jit.script\ndef drelu_dscale2(grad_o, output, scale1, scale2):\n    relu_mask = (output>0).half()\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    g2 = dx_relu * scale2\n    return g1, g2\n\nclass BottleneckFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):\n        # TODO: clean up order of tensors\n        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]\n        ctx.downsample = len(conv) > 3\n        if ctx.downsample:\n            
args.append(conv[3])\n            args.append(scale[3])\n            args.append(bias[3])\n\n        # weight buffers are always in nhwc while shape can be nhwc or channels_last\n        # here we pass in flag and let c++ handle it\n        # alternatively, we can put all sizes into a fixed format and pass it in\n        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)\n        ctx.save_for_backward(*(args+outputs))\n        # save relu outputs for drelu\n        ctx.nhwc = nhwc\n        ctx.stride_1x1 = stride_1x1\n        return outputs[2]\n\n    # backward relu is not exposed, MUL with mask used now\n    # only support dgrad\n    @staticmethod\n    def backward(ctx, grad_o):\n        outputs = ctx.saved_tensors[-3:]\n\n        if ctx.downsample:\n            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])\n        else:\n            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])\n\n        # create input vector for backward\n        t_list = [*ctx.saved_tensors[0:10]]\n        t_list.append(grad_conv3)\n        t_list.append(grad_conv4)\n\n        # outputs used for wgrad and generating drelu mask\n        t_list.append(outputs[0])\n        t_list.append(outputs[1])\n\n        # in case there is downsample\n        if ctx.downsample:\n            t_list.append(ctx.saved_tensors[10])\n\n        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)\n\n        return (None, None, None, None, *grads)\n\nbottleneck_function = BottleneckFunction.apply\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, 
kernel_size=1, stride=stride, bias=False)\n\nclass Bottleneck(torch.nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n    # here we put it at 1x1\n\n    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,\n                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):\n        super(Bottleneck, self).__init__()\n        if groups != 1:\n            raise RuntimeError('Only support groups == 1')\n        if dilation != 1:\n            raise RuntimeError('Only support dilation == 1')\n        if norm_func == None:\n            norm_func = FrozenBatchNorm2d\n        else:\n            raise RuntimeError('Only support frozen BN now.')\n\n        if stride != 1 or in_channels != out_channels:\n            self.downsample = nn.Sequential(\n                conv1x1(in_channels, out_channels, stride),\n                norm_func(out_channels),\n            )\n        else:\n            self.downsample = None\n\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)\n        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)\n        self.conv3 = conv1x1(bottleneck_channels, out_channels)\n        self.relu = nn.ReLU(inplace=True)\n        self.stride = stride\n\n        self.bn1 = norm_func(bottleneck_channels)\n        self.bn2 = norm_func(bottleneck_channels)\n        self.bn3 = norm_func(out_channels)\n\n        self.use_cudnn = use_cudnn\n\n        # setup conv weights\n      
  self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]\n        if self.downsample is not None:\n            self.w_conv.append(self.downsample[0].weight)\n\n        # init weight in nchw format before possible transpose\n        for w in self.w_conv:\n            kaiming_uniform_(w, a=1)\n\n        # TODO: prevent unsupported case usage\n        # support cases\n        #                 native      cudnn\n        # normal             yes         no\n        # channel_last       yes        yes\n        # explicit_nhwc       no        yes\n        self.explicit_nhwc = explicit_nhwc\n        if self.explicit_nhwc:\n            for p in self.parameters():\n                with torch.no_grad():\n                    p.data = p.data.permute(0,2,3,1).contiguous()\n        return\n\n    def forward(self, x):\n        if self.use_cudnn:\n            # calculate scale/bias from registered buffers\n            # TODO: make this better\n            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)\n            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)\n            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)\n            w_scale = [s1, s2, s3]\n            w_bias = [b1, b2, b3]\n            if self.downsample is not None:\n                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)\n                w_scale.append(s4)\n                w_bias.append(b4)\n\n            out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)\n            return out\n\n        if self.explicit_nhwc:\n            raise RuntimeError('explicit nhwc with native ops is not supported.')\n\n        # fallback to native ops\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if 
self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/bottleneck/test.py",
    "content": "import torch\nfrom bottleneck import Bottleneck\ntorch.manual_seed(23337)\n\n# use True to print layerwise sum for all outputs in reference code path\nDEBUG = False#True\n\nfor stride, o_channel in [(1,32), (1,128), (2,32)]:\n    print(\"testing stride ==\", stride, \", in_channel == 32 , out_channel ==\", o_channel)\n    a_ = torch.randn(17,32,28,28)\n\n    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()\n    model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)\n\n    # test model\n    b = model(a)\n    b.mean().backward()\n    d_grad = a.grad.float()\n    a.grad = None\n    torch.cuda.synchronize()\n\n    if DEBUG:\n        print(\"[DEBUG] ref dx :\", d_grad.sum().item())\n        # print wgrad. we don't need to reset since later cpp print before accumulation\n        for i, w in enumerate(model.w_conv):\n            print(\"[DEBUG] ref wgrad{} :\".format(i+1), w.grad.sum().item())\n\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float())\n\n    model.use_cudnn = True\n    model.zero_grad()\n    c = model(a)\n    c.mean().backward()\n\n    torch.cuda.synchronize()\n    print(\"comparing native and channels_last:\")\n    print(\"max error fprop:\", (b-c).abs().max().item(), \"max elem:\", b.abs().max().item())\n    print(\"max error dgrad:\", (d_grad-a.grad.float()).abs().max().item(), \"max elem:\", d_grad.abs().max().item())\n    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):\n        print(\"max error wgrad{}:\".format(i+1), (wgrad - w.grad.float()).abs().max().item(), \"max elem:\", wgrad.abs().max().item())\n\n    nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()\n    nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()\n    for p,q in zip(model.parameters(), nhwc_model.parameters()):\n        # model's storage is already in nhwc, we clone and assign to 
explicit nhwc model\n        q.data.copy_(p.data.permute(0,2,3,1).contiguous())\n    for p,q in zip(model.buffers(), nhwc_model.buffers()):\n        q.data.copy_(p.data)\n\n    d = nhwc_model(nhwc_a)\n    d.mean().backward()\n    torch.cuda.synchronize()\n\n    # reset reference to cudnn channels_last permute\n    #c_s = c.storage().tolist()\n    #d_s = d.storage().tolist()\n    #print(max([x-y for x,y in zip(c_s,d_s)]))\n    c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()\n    d_grad = a.grad.float().permute(0,2,3,1).contiguous()\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())\n\n    torch.cuda.synchronize()\n    print(\"comparing nhwc and channels_last:\")\n    print(\"max error fprop:\", (d-c).abs().max().item(), \"max elem:\", c.abs().max().item())\n    print(\"max error dgrad:\", (d_grad-nhwc_a.grad.float()).abs().max().item(), \"max elem:\", d_grad.abs().max().item())\n    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):\n        print(\"max error wgrad{}:\".format(i+1), (wgrad - w.grad.float()).abs().max().item(), \"max elem:\", wgrad.abs().max().item())\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/bottleneck/bottleneck.cpp",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h>  // for getcudnnhandle\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <vector>\n#include <cudnn_frontend.h>\n\n#include <iostream>\n\n#ifdef DEBUG\n#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )\n#else\n#define DEBUG_MSG(str) do { } while ( false )\n#endif\n\n#ifdef DEBUG_CUDNN\n#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )\n#else\n#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )\n#endif\n\n#define checkCudnnErr(...)                                                        \\\n    do {                                                                          \\\n        int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n        if (err) {                                                                \\\n            return;                                                    \\\n\t}                                                                         \\\n    } while (0)\n\n\nint checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {\n    if (code) {\n        printf(\"CUDNN error at %s:%d, code=%d (%s) in '%s'\\n\", file, line, (int)code, cudnnGetErrorString(code), expr);\n        return 1;\n    }\n    return 0;\n}\n\nvoid checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);\n#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); }    // in-line regular function\n\nvoid checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort)\n{\n  if (code != cudaSuccess)\n  {\n    const char * errorMessage = cudaGetErrorString(code);\n    fprintf(stderr, \"CUDA error returned from \\\"%s\\\" at %s:%d, Error code: %d (%s)\\n\", func, file, line, code, errorMessage);\n    if (abort){\n      cudaDeviceReset();\n      exit(code);\n    }\n  }\n}\n\nvoid 
generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {\n    // For INT8x4 and INT8x32 we still compute standard strides here to input\n    // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.\n    if (filterFormat == CUDNN_TENSOR_NCHW) {\n        strideA[nbDims - 1] = 1;\n        for (int64_t d = nbDims - 2; d >= 0; d--) {\n            strideA[d] = strideA[d + 1] * dimA[d + 1];\n        }\n    } else {\n        // Here we assume that the format is CUDNN_TENSOR_NHWC\n\tstrideA[1]          = 1;\n        strideA[nbDims - 1] = strideA[1] * dimA[1];\n        for (int64_t d = nbDims - 2; d >= 2; d--) {\n            strideA[d] = strideA[d + 1] * dimA[d + 1];\n        }\n        strideA[0] = strideA[2] * dimA[2];\n    }\n}\n\n\nint getFwdConvDilatedFilterDim(int filterDim, int dilation) {\n    return ((filterDim - 1) * dilation) + 1;\n}\n\nint getFwdConvPaddedImageDim(int tensorDim, int pad) {\n    return tensorDim + (2 * pad);\n}\n\nint getFwdConvOutputDim(\n    int tensorDim,\n    int pad,\n    int filterDim,\n    int stride,\n    int dilation)\n{\n    int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;\n    return (p);\n}\n\nenum {\n    X_TENSOR,\n    Y_TENSOR,\n    W_TENSOR,\n    Z_TENSOR,\n    B_TENSOR,\n    AFTERADD_TENSOR,\n    AFTERBIAS_TENSOR,\n    AFTERCONV_TENSOR,\n    OPTIONAL,\n    AFTEROPT_TENSOR,\n};\n\nusing common_conv_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;\n\n\ncommon_conv_descriptors\ncreate_common_descriptors(int64_t* x_dim_padded,\n                          int64_t* padA,\n                          int64_t* convstrideA,\n                          int64_t* dilationA,\n                          int64_t* w_dim_padded,\n                          int64_t* y_dim_padded,\n                          cudnnDataType_t dataType,\n        
                  cudnnConvolutionMode_t mode) {\n    const int convDim = 2;\n\n    int64_t strideA_padded[4];\n    int64_t outstrideA_padded[4];\n    int64_t filterstrideA_padded[4];\n\n    generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return common_conv_descriptors(cudnn_frontend::TensorBuilder()\n                                       .setDim(4, x_dim_padded)\n                                       .setStrides(4, strideA_padded)\n                                       .setId('x')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::TensorBuilder()\n                                       .setDim(4, y_dim_padded)\n                                       .setStrides(4, outstrideA_padded)\n                                       .setId('y')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::TensorBuilder()\n                                       .setDim(4, w_dim_padded)\n                                       .setStrides(4, filterstrideA_padded)\n                                       .setId('w')\n                                       .setAlignment(16)\n                                       .setDataType(dataType)\n                                       .build(),\n                                   cudnn_frontend::ConvDescBuilder()\n                                       .setDataType(CUDNN_DATA_FLOAT)\n                                       .setMathMode(mode)\n                                       .setNDims(convDim)\n                               
        .setStrides(convDim, convstrideA)\n                                       .setPrePadding(convDim, padA)\n                                       .setPostPadding(convDim, padA)\n                                       .setDilation(convDim, dilationA)\n                                       .build());\n}\n\nusing common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor,\n                                               cudnn_frontend::Tensor>;\n\ncommon_convbias_descriptors\ncreate_conv_bias_add_act_descriptors(int64_t* x_dim_padded,\n                                     int64_t* padA,\n                                     int64_t* convstrideA,\n                                     int64_t* dilationA,\n                                     int64_t* w_dim_padded,\n                                     int64_t* y_dim_padded,\n                                     cudnnDataType_t dataType) {\n    const int convDim = 2;\n\n    int64_t b_dim_padded[4];\n    b_dim_padded[0] = 1;\n    b_dim_padded[1] = y_dim_padded[1];\n    b_dim_padded[2] = 1;\n    b_dim_padded[3] = 1;\n\n    int64_t x_stride_padded[4];\n    int64_t y_stride_padded[4];\n    int64_t w_stride_padded[4];\n    int64_t b_stride_padded[4];\n\n    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, y_stride_padded, 
4, CUDNN_TENSOR_NHWC);\n    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return common_convbias_descriptors(cudnn_frontend::TensorBuilder()\n                                           .setDim(4, x_dim_padded)\n                                           .setStrides(4, x_stride_padded)\n                                           .setId('x')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('y')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, w_dim_padded)\n                                           .setStrides(4, w_stride_padded)\n                                           .setId('w')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, b_dim_padded)\n                                           .setStrides(4, b_stride_padded)\n                                           .setId('z')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n   
                                        .setDim(4, b_dim_padded)\n                                           .setStrides(4, b_stride_padded)\n                                           .setId('b')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setVirtual()\n                                           .setId('A')  // after add\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setVirtual()\n                                           .setId('B')  // after bias\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('C')  // after conv\n                                           .setAlignment(16)\n                                           .setVirtual()\n                                           .setDataType(dataType)\n                                           .build(),\n                                 
      cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('i')\n                                           .setAlignment(16)\n                                           .setDataType(dataType)\n                                           .build(),\n                                       cudnn_frontend::TensorBuilder()\n                                           .setDim(4, y_dim_padded)\n                                           .setStrides(4, y_stride_padded)\n                                           .setId('D')  // after optional add\n                                           .setAlignment(16)\n                                           .setVirtual()\n                                           .setDataType(dataType)\n                                           .build());\n}\n\n// tensor descriptors used for dgrad\nenum {\n    X_OR_DX_TENSOR,\n    DY_TENSOR,\n    W_OR_DW_TENSOR,\n    SCALE_TENSOR,\n    RELU_TENSOR,\n    AFTER_DCONV_TENSOR,\n    AFTER_DRELU_TENSOR,\n};\n\nusing dconv_descriptors = std::tuple<cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor,\n                                     cudnn_frontend::Tensor>;\n\ndconv_descriptors\ncreate_dconv_descriptors(int64_t* x_dim_padded,\n                         int64_t* padA,\n                         int64_t* convstrideA,\n                         int64_t* dilationA,\n                         int64_t* w_dim_padded,\n                         int64_t* y_dim_padded,\n                         cudnnDataType_t dataType) {\n    const int convDim = 2;\n\n    int64_t 
b_dim_padded[4];\n    b_dim_padded[0] = 1;\n    b_dim_padded[1] = x_dim_padded[1];\n    b_dim_padded[2] = 1;\n    b_dim_padded[3] = 1;\n\n    int64_t x_stride_padded[4];\n    int64_t y_stride_padded[4];\n    int64_t w_stride_padded[4];\n    int64_t b_stride_padded[4];\n\n    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n    return dconv_descriptors(cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setId('x')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, y_dim_padded)\n                             .setStrides(4, y_stride_padded)\n                             .setId('y')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, w_dim_padded)\n                             .setStrides(4, w_stride_padded)\n                             .setId('w')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, b_dim_padded)\n                             .setStrides(4, b_stride_padded)\n                             .setId('s')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                  
           .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setId('r')\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setVirtual()\n                             .setId('A')  // after dconv\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build(),\n                             cudnn_frontend::TensorBuilder()\n                             .setDim(4, x_dim_padded)\n                             .setStrides(4, x_stride_padded)\n                             .setVirtual()\n                             .setId('B')  // after drelu\n                             .setAlignment(16)\n                             .setDataType(dataType)\n                             .build());\n}\n\n// create a cache for plan\nstd::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;\n\n// TODO: better name\nstd::string getConvFusionString(int64_t* x_dim_padded,\n                                int64_t* padA,\n                                int64_t* convstrideA,\n                                int64_t* dilationA,\n                                int64_t* w_dim_padded,\n                                cudnnDataType_t dataType,\n                                std::string fusion_string) {\n\n  for(int i=0;i<4;i++) {\n    fusion_string += 'X';\n    fusion_string += std::to_string(x_dim_padded[i]);\n  }\n  for(int i=0;i<4;i++) {\n    fusion_string += 'W';\n    fusion_string += std::to_string(w_dim_padded[i]);\n  }\n  for(int 
i=0;i<2;i++) {\n    fusion_string += 'P';\n    fusion_string += std::to_string(padA[i]);\n  }\n  for(int i=0;i<2;i++) {\n    fusion_string += 'S';\n    fusion_string += std::to_string(convstrideA[i]);\n  }\n  for(int i=0;i<2;i++) {\n    fusion_string += 'D';\n    fusion_string += std::to_string(dilationA[i]);\n  }\n  fusion_string += 'T';\n  fusion_string += std::to_string(dataType);\n  return fusion_string;\n}\n\ncudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,\n                                               std::stringstream& log_buf,\n                                               cudnn_frontend::OperationGraph& opGraph,\n                                               std::string cache_string,\n                                               bool use_heuristic = true){\n  auto it = plan_cache.find(cache_string);\n  if (it != plan_cache.end()) {\n    DEBUG_CUDNN_MSG(log_buf, \"Found plan in cache\");\n    return it->second;\n  } else {\n    if (use_heuristic){\n      // TODO: confirm which mode to use\n      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()\n        .setOperationGraph(opGraph)\n        .setHeurMode(CUDNN_HEUR_MODE_INSTANT)\n        .build();\n      // try 3 times for now as WAR for no heuristic training\n      int max_tries = 3, count = 0;\n      auto& engine_configs = heuristics.getEngineConfig(max_tries);\n      while(true) {\n        try {\n          plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()\n                                                     .setHandle(handle_)\n                                                     .setEngineConfig(engine_configs[count], opGraph.getTag())\n                                                     .build()));\n          break;\n        } catch (cudnn_frontend::cudnnException e) {\n          if (++count == max_tries) throw e;\n        }\n      }\n    }else{\n    DEBUG_CUDNN_MSG(log_buf, \"No plan in cache\");\n    // How many engines support this 
operation graph ?\n    auto total_engines = opGraph.getEngineCount();\n    DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << \" has \" << total_engines << \" engines.\");\n    // We have to randomly pick one engine from [0, total_engines)\n    // Selecting \"0\" by default\n    auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();\n    DEBUG_CUDNN_MSG(log_buf, engine.describe());\n    auto& knobs = engine.getSupportedKnobs();\n    for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {\n      DEBUG_CUDNN_MSG(log_buf, it->describe());\n    }\n    if (knobs.begin() != knobs.end()) {\n      DEBUG_CUDNN_MSG(log_buf, \"Updated knob choice\");\n      knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);\n      DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());\n    }\n\n    // Create the requisite engine config\n    auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();\n    DEBUG_CUDNN_MSG(log_buf, engine_config.describe());\n    plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));\n    }\n\n    return plan_cache.find(cache_string)->second;\n  }\n}\n\nvoid\nrun_conv_scale_bias_add_activation(int64_t* x_dim_padded,\n                                   int64_t* pad,\n                                   int64_t* convstride,\n                                   int64_t* dilation,\n                                   int64_t* w_dim_padded,\n                                   int64_t* y_dim_padded,\n                                   cudnnDataType_t dataType,\n                                   at::Half* devPtrX,\n                                   at::Half* devPtrW,\n                                   at::Half* devPtrY,\n                                   at::Half* devPtrZ,\n                                   at::Half* devPtrB,\n                                   at::Half* devPtrI) {\n    
cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n        // Define the scale (elementwise multiply) operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n                           .setMode(CUDNN_POINTWISE_MUL)\n                           .setMathPrecision(CUDNN_DATA_FLOAT)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        // Define the bias operation\n        auto biasDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n        // Define the optional residual add operation\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, 
addDesc.describe());\n\n        // Define the activation operation\n        auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                           .setMode(CUDNN_POINTWISE_RELU_FWD)\n                           .setMathPrecision(CUDNN_DATA_FLOAT)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                           .setxDesc(std::get<X_TENSOR>(tensors))\n                           .setwDesc(std::get<W_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                           .setcDesc(convDesc)\n                           .setAlpha(alpha)\n                           .setBeta(beta)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // Create a Scale Node (elementwise multiply by the scale tensor).\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(conv_op.getOutputTensor())\n                           .setbDesc(std::get<Z_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERADD_TENSOR>(tensors))\n               
            .setpwDesc(scaleDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create a Bias Node.\n        auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(scale_op.getOutputTensor())\n                           .setbDesc(std::get<B_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))\n                           .setpwDesc(biasDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n        // Create an optional Add Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(bias_op.getOutputTensor())\n                           .setbDesc(std::get<OPTIONAL>(tensors))\n                           .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))\n                           .setpwDesc(addDesc)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n\n        // Create an Activation Node.\n        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())\n                          .setyDesc(std::get<Y_TENSOR>(tensors))\n                          .setpwDesc(actDesc)\n                          .build();\n        DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n        // Create an Operation Graph: convolution scale bias (optional add) activation.\n        // When devPtrI is null the optional add is skipped, so only the first 4 ops are used.\n        std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(devPtrI ? 
ops.size() : 4, ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};\n        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 'i'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n                               .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(devPtrI ? 6 : 5, data_ptrs)\n          .setUids(devPtrI ? 
6 : 5, uids)\n                               .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_conv_scale_bias(int64_t* x_dim_padded,\n                    int64_t* pad,\n                    int64_t* convstride,\n                    int64_t* dilation,\n                    int64_t* w_dim_padded,\n                    int64_t* y_dim_padded,\n                    cudnnDataType_t dataType,\n                    at::Half* devPtrX,\n                    at::Half* devPtrW,\n                    at::Half* devPtrY,\n                    at::Half* devPtrZ,\n                    at::Half* devPtrB) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, 
std::get<AFTERCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n        // Define the scale (elementwise multiply) operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_MUL)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        // Define the bias operation\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_ADD)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                           .setxDesc(std::get<X_TENSOR>(tensors))\n                           .setwDesc(std::get<W_TENSOR>(tensors))\n                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                           .setcDesc(convDesc)\n                           .setAlpha(alpha)\n                           .setBeta(beta)\n                           .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // 
Create a Scale Node (elementwise multiply by the scale tensor).\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(conv_op.getOutputTensor())\n          .setbDesc(std::get<Z_TENSOR>(tensors))\n          .setyDesc(std::get<AFTERADD_TENSOR>(tensors)) // TODO: change enum to aftermul\n          .setpwDesc(scaleDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create a Bias Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(scale_op.getOutputTensor())\n          .setbDesc(std::get<B_TENSOR>(tensors))\n          .setyDesc(std::get<Y_TENSOR>(tensors))\n          .setpwDesc(addDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n        // Create an Operation Graph. In this case it is convolution scale bias\n        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = 
workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};\n        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n                               .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(5, data_ptrs)\n          .setUids(5, uids)\n                               .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\n\nvoid\nrun_dconv_drelu_dscale(int64_t* x_dim_padded,\n                       int64_t* pad,\n                       int64_t* convstride,\n                       int64_t* dilation,\n                       int64_t* w_dim_padded,\n                       int64_t* y_dim_padded,\n                       cudnnDataType_t dataType,\n                       at::Half* devPtrX,\n                       at::Half* devPtrW,\n                       at::Half* devPtrY,\n                       at::Half* devPtrZ,\n                       at::Half* devPtrR) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, 
std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        // Define the activation backward operation\n        auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_RELU_BWD)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n        // Define the scale backward operation\n        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_MUL)\n          .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n          .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          
.setAlpha(alpha)\n          .setBeta(beta)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // TODO: do we need getOutputTensor(), and what does it return in the backward case?\n        // Create a ReLU backward Node.\n        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setxDesc(std::get<RELU_TENSOR>(tensors))\n          .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n          .setpwDesc(actDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n        // Create a Scale Node.\n        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n          .setbDesc(std::get<SCALE_TENSOR>(tensors))\n          .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setpwDesc(scaleDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n        // Create an Operation Graph. 
In this case it is convolution add bias activation\n        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};\n        int64_t uids[]    = {'x', 'y', 'w', 's', 'r'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(5, data_ptrs)\n          .setUids(5, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_dconv(int64_t* 
x_dim_padded,\n          int64_t* pad,\n          int64_t* convstride,\n          int64_t* dilation,\n          int64_t* w_dim_padded,\n          int64_t* y_dim_padded,\n          cudnnDataType_t dataType,\n          at::Half* devPtrX,\n          at::Half* devPtrW,\n          at::Half* devPtrY,\n          cudnnBackendDescriptorType_t mode) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        // mode should be one 
of following\n        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR\n        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR\n        auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);\n        if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {\n          conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n            .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n            .setdyDesc(std::get<DY_TENSOR>(tensors))\n            .setcDesc(convDesc)\n            .setAlpha(alpha)\n            .setBeta(beta);\n        }\n        else {\n          conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n            .setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n            .setdyDesc(std::get<DY_TENSOR>(tensors))\n            .setcDesc(convDesc)\n            .setAlpha(alpha)\n            .setBeta(beta);\n        }\n        auto conv_op = conv_op_builder.build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // Create an Operation Graph. 
In this case it is convolution add bias activation\n        std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};\n        int64_t uids[]    = {'x', 'y', 'w'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(3, data_ptrs)\n          .setUids(3, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\nvoid\nrun_dconv_add(int64_t* x_dim_padded,\n              int64_t* pad,\n   
           int64_t* convstride,\n              int64_t* dilation,\n              int64_t* w_dim_padded,\n              int64_t* y_dim_padded,\n              cudnnDataType_t dataType,\n              at::Half* devPtrX,\n              at::Half* devPtrW,\n              at::Half* devPtrY,\n              at::Half* devPtrR) {\n    cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n    std::stringstream log_buf;\n    try {\n        int convDim = 2;\n\n        // Creates the necessary tensor descriptors\n        dconv_descriptors tensors = create_dconv_descriptors(\n            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n        // Define the convolution problem\n        auto convDesc = cudnn_frontend::ConvDescBuilder()\n                            .setDataType(CUDNN_DATA_FLOAT)\n                            .setMathMode(CUDNN_CROSS_CORRELATION)\n                            .setNDims(convDim)\n                            .setStrides(convDim, convstride)\n                            .setPrePadding(convDim, pad)\n                            .setPostPadding(convDim, pad)\n                            .setDilation(convDim, dilation)\n                            .build();\n        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n        // Define the add backward operation\n        auto addDesc = cudnn_frontend::PointWiseDescBuilder()\n          .setMode(CUDNN_POINTWISE_ADD)\n     
     .setMathPrecision(CUDNN_DATA_FLOAT)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n        float alpha  = 1.0f;\n        float beta   = 0.0f;\n\n        // Create a convolution Node\n        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n          .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          .setAlpha(alpha)\n          .setBeta(beta)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n        // TODO: do we need getOutputTensor(), and what it returns in backward case?\n        // Create add Node.\n        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n          .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n          .setbDesc(std::get<RELU_TENSOR>(tensors))\n          .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setpwDesc(addDesc)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n        // Create an Operation Graph. 
In this case it is convolution add bias activation\n        std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};\n\n        auto opGraph = cudnn_frontend::OperationGraphBuilder()\n          .setHandle(handle_)\n          .setOperationGraph(ops.size(), ops.data())\n          .build();\n\n        // Create string encoding for plan caching\n        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n        DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n        DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n        auto workspace_size = plan.getWorkspaceSize();\n        DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n        void* workspace_ptr = nullptr;\n        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n        if (workspace_size > 0) {\n          workspace_ptr = workspace_tensor.data_ptr<float>();\n        }\n        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};\n        int64_t uids[]    = {'x', 'y', 'w', 'r'};\n        auto variantPack  = cudnn_frontend::VariantPackBuilder()\n          .setWorkspacePointer(workspace_ptr)\n          .setDataPointers(4, data_ptrs)\n          .setUids(4, uids)\n          .build();\n        DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n        checkCudnnErr(status);\n        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\");\n    } catch (cudnn_frontend::cudnnException e) {\n      std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n    }\n}\n\n\n// inputs contains 
x,w,z,b,(i)\nstd::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[]         = {0, 0, 0, 0};\n  int64_t filterdimA1[]  = {0, 0, 0, 0};\n  int64_t filterdimA2[]  = {0, 0, 0, 0};\n  int64_t filterdimA3[]  = {0, 0, 0, 0};\n  int64_t filterdimA4[]  = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] {0,1,2,3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim=0;dim<4;dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim=0;dim<4;dim++) {\n      filterdimA4[dim] = inputs[10].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA2[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA3[]     = {0, 0, 0, 0}; // Computed Below\n\n  // use these fixed value for test run\n  int64_t padA[]        = {0, 0};\n  int64_t padA1[]        = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 
2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[]     = {0, 0, 0, 0};\n  int64_t outdim2[]     = {0, 0, 0, 0};\n  int64_t outdim3[]     = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim=0;dim<4;dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n  at::Half* b = inputs[7].data_ptr<at::Half>();\n  auto out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(dimA,\n                                     padA,\n                                     convstride1X1,\n                                     dilationA,\n                                     filterdimA1,\n                                     outdimA1,\n                                     CUDNN_DATA_HALF,\n                                     x,\n                                     w,\n                                     y1,\n                                     z,\n                                     b,\n                                     nullptr);\n\n  DEBUG_MSG(\"[DEBUG] new relu1 : \" << out1.to(at::kFloat).sum().item<float>());\n\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[5].data_ptr<at::Half>();\n  b = inputs[8].data_ptr<at::Half>();\n  auto out2 = at::empty(outdim2, inputs[0].type(), 
output_format);\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA1,\n                                     padA1,\n                                     convstrideA,\n                                     dilationA,\n                                     filterdimA2,\n                                     outdimA2,\n                                     CUDNN_DATA_HALF,\n                                     y1,\n                                     w,\n                                     y2,\n                                     z,\n                                     b,\n                                     nullptr);\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n\n  // create output of conv3\n  auto out3 = at::empty(outdim3, inputs[0].type(), output_format);\n  at::Half* y3 = out3.data_ptr<at::Half>();\n\n  // create output of conv4 that may exist\n  auto identity = at::empty_like(out3);\n  at::Half* yi = identity.data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){\n\n    w = inputs[10].data_ptr<at::Half>();\n    z = inputs[11].data_ptr<at::Half>();\n    b = inputs[12].data_ptr<at::Half>();\n    run_conv_scale_bias(dimA,\n                        padA,\n                        convstride1X1,\n                        dilationA,\n                        filterdimA4,\n                        outdimA3,\n                        CUDNN_DATA_HALF,\n                        x,\n                        w,\n                        yi,\n                        z,\n                        b);\n    DEBUG_MSG(\"[DEBUG] new downsample : \" << identity.to(at::kFloat).sum().item<float>());\n  }\n  else {\n    yi = x;\n  }\n\n  w = inputs[3].data_ptr<at::Half>();\n  z = inputs[6].data_ptr<at::Half>();\n  b = inputs[9].data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA2,\n                                     padA,\n                                     convstrideA,\n 
                                    dilationA,\n                                     filterdimA3,\n                                     outdimA3,\n                                     CUDNN_DATA_HALF,\n                                     y2,\n                                     w,\n                                     y3,\n                                     z,\n                                     b,\n                                     yi);\n  DEBUG_MSG(\"[DEBUG] new relu3 : \" << out3.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out1);\n  outputs.push_back(out2);\n  outputs.push_back(out3);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[]         = {0, 0, 0, 0};\n  int64_t filterdimA1[]  = {0, 0, 0, 0};\n  int64_t filterdimA2[]  = {0, 0, 0, 0};\n  int64_t filterdimA3[]  = {0, 0, 0, 0};\n  int64_t filterdimA4[]  = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] {0,1,2,3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim=0;dim<4;dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim=0;dim<4;dim++) {\n      filterdimA4[dim] = inputs[14].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t outdimA2[]     = {0, 0, 0, 0}; // Computed Below\n  int64_t 
outdimA3[]     = {0, 0, 0, 0}; // Computed Below\n\n  // use these fixed value for test run\n  int64_t padA[]        = {0, 0};\n  int64_t padA1[]        = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[]     = {0, 0, 0, 0};\n  int64_t outdim2[]     = {0, 0, 0, 0};\n  int64_t outdim3[]     = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim=0;dim<4;dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // dconv3+drelu2+dscale2\n  at::Half* conv_in = inputs[13].data_ptr<at::Half>();\n  at::Half* dy3 = inputs[10].data_ptr<at::Half>();\n\n  DEBUG_MSG(\"[DEBUG] new dconv3 : \" << inputs[10].to(at::kFloat).sum().item<float>());\n\n  // wgrad\n  auto wgrad3 = at::empty_like(inputs[3]);\n  at::Half* dw3 = wgrad3.data_ptr<at::Half>();\n  run_dconv(outdimA2,\n            padA,\n            convstrideA,\n            dilationA,\n            filterdimA3,\n            outdimA3,\n            
CUDNN_DATA_HALF,\n            conv_in,\n            dw3,\n            dy3,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n  at::Half* w = inputs[3].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n\n  at::Half* relu2 = inputs[13].data_ptr<at::Half>();\n\n  run_dconv_drelu_dscale(outdimA2,\n                         padA,\n                         convstrideA,\n                         dilationA,\n                         filterdimA3,\n                         outdimA3,\n                         CUDNN_DATA_HALF,\n                         dy2,\n                         w,\n                         dy3,\n                         z,\n                         relu2);\n\n  DEBUG_MSG(\"[DEBUG] new dconv2 : \" << grad_out2.to(at::kFloat).sum().item<float>());\n\n  // dconv2+drelu1+dscale1\n  conv_in = inputs[12].data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2 = at::empty_like(inputs[2]);\n  at::Half* dw2 = wgrad2.data_ptr<at::Half>();\n  run_dconv(outdimA1,\n            padA1,\n            convstrideA,\n            dilationA,\n            filterdimA2,\n            outdimA2,\n            CUDNN_DATA_HALF,\n            conv_in,\n            dw2,\n            dy2,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n  // fused dgrad\n  run_dconv_drelu_dscale(outdimA1,\n                         padA1,\n                         convstrideA,\n                         dilationA,\n                         filterdimA2,\n                         outdimA2,\n                         CUDNN_DATA_HALF,\n  
                       dy1,\n                         w,\n                         dy2,\n                         z,\n                         relu1);\n\n/*\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (stride_1X1 != 1){\n    // dgrad\n    run_dconv(outdimA1,\n              padA1,\n              convstride1X1,\n              dilationA,\n              filterdimA2,\n              outdimA2,\n              CUDNN_DATA_HALF,\n              dy1,\n              w,\n              dy2,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n    // mul fused mask\n    grad_out1.mul_(inputs[15]);\n  }\n  else {\n    at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n    // fused dgrad\n    run_dconv_drelu_dscale(outdimA1,\n                           padA1,\n                           convstride1X1,\n                           dilationA,\n                           filterdimA2,\n                           outdimA2,\n                           CUDNN_DATA_HALF,\n                           dy1,\n                           w,\n                           dy2,\n                           z,\n                           relu1);\n  }\n*/\n  DEBUG_MSG(\"[DEBUG] new dconv1 : \" << grad_out1.to(at::kFloat).sum().item<float>());\n\n  // create grads of conv4 that may exist\n  auto grad_x_conv4 = at::empty_like(inputs[0]);\n  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();\n  at::Tensor wgrad4;\n\n  // x used for dconv1 and dconv4 wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){\n    w = inputs[14].data_ptr<at::Half>();\n    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();\n    if (requires_grad) {\n      run_dconv(dimA,\n                padA,\n                convstride1X1,\n                dilationA,\n                filterdimA4,\n                outdimA3,\n                CUDNN_DATA_HALF,\n                dx_conv4,\n           
     w,\n                dy_conv4,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx\n      // DEBUG_MSG(\"[DEBUG] new dx_identity : \" << grad_x_conv4.to(at::kFloat).sum().item<float>());\n    }\n    // wgrad\n    wgrad4 = at::empty_like(inputs[14]);\n    at::Half* dw4 = wgrad4.data_ptr<at::Half>();\n    run_dconv(dimA,\n              padA,\n              convstride1X1,\n              dilationA,\n              filterdimA4,\n              outdimA3,\n              CUDNN_DATA_HALF,\n              x,\n              dw4,\n              dy_conv4,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  }\n  else {\n    // if there is no downsample, dx_conv4 is fork of drelu3\n    dx_conv4 = inputs[11].data_ptr<at::Half>();\n  }\n\n  // dconv1+add\n  // wgrad\n  auto wgrad1 = at::empty_like(inputs[1]);\n  at::Half* dw1 = wgrad1.data_ptr<at::Half>();\n  run_dconv(dimA,\n            padA,\n            convstride1X1,\n            dilationA,\n            filterdimA1,\n            outdimA1,\n            CUDNN_DATA_HALF,\n            x,\n            dw1,\n            dy1,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  w = inputs[1].data_ptr<at::Half>();\n  auto grad_x = at::empty_like(inputs[0]);\n  at::Half* dx = grad_x.data_ptr<at::Half>();\n\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (requires_grad){\n    if (stride_1X1 != 1){\n      run_dconv(dimA,\n                padA,\n                convstride1X1,\n                dilationA,\n                filterdimA1,\n                outdimA1,\n                CUDNN_DATA_HALF,\n                dx,\n                w,\n                dy1,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // add 2 together\n      
grad_x.add_(grad_x_conv4);\n    }\n    else {\n      run_dconv_add(dimA,\n                    padA,\n                    convstride1X1,\n                    dilationA,\n                    filterdimA1,\n                    outdimA1,\n                    CUDNN_DATA_HALF,\n                    dx,\n                    w,\n                    dy1,\n                    dx_conv4);\n    }\n  }\n\n  DEBUG_MSG(\"[DEBUG] new dx : \" << grad_x.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad1 : \" << wgrad1.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad2 : \" << wgrad2.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad3 : \" << wgrad3.to(at::kFloat).sum().item<float>());\n  outputs.push_back(grad_x);\n  outputs.push_back(wgrad1);\n  outputs.push_back(wgrad2);\n  outputs.push_back(wgrad3);\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    DEBUG_MSG(\"[DEBUG] new wgrad4 : \" << wgrad4.to(at::kFloat).sum().item<float>());\n    outputs.push_back(wgrad4);\n  }\n\n  return outputs;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &bottleneck_forward, \"Bottleneck block forward\");\n  m.def(\"backward\", &bottleneck_backward, \"Bottleneck block backward\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/fmha_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include \"fmha.h\"\n\nvoid set_params(Fused_multihead_attention_fprop_params &params,\n                // sizes\n                const size_t b,\n                const size_t s,\n                const size_t h,\n                const size_t d,\n                // device pointers\n                void *qkv_packed_d,\n                void *cu_seqlens_d,\n                void *o_packed_d,\n                void *s_d,\n                float p_dropout) {\n\n    Data_type acc_type = DATA_TYPE_FP32;\n    Data_type data_type = DATA_TYPE_FP16;\n\n    // Reset the parameters\n    memset(&params, 0, sizeof(params));\n\n    // Set the pointers and strides.\n    params.qkv_ptr = qkv_packed_d;\n    params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);\n    params.o_ptr = o_packed_d;\n    params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);\n\n    params.cu_seqlens = static_cast<int *>(cu_seqlens_d);\n\n    // S = softmax(P)\n    params.s_ptr = s_d;\n    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);\n\n    // Set the dimensions.\n    params.b = b;\n    params.h = h;\n    params.s = s;\n    params.d = d;\n\n    // Set the different scale values.\n    const float scale_bmm1 = 1.f / sqrtf(d);\n    constexpr float scale_softmax = 1.f;\n    constexpr float scale_bmm2 = 
1.f;\n\n    set_alpha(params.scale_bmm1, scale_bmm1, acc_type);\n    set_alpha(params.scale_softmax, scale_softmax, acc_type);\n    set_alpha(params.scale_bmm2, scale_bmm2, data_type);\n\n    // Set this to probability of keeping an element to simplify things.\n    params.p_dropout = 1.f - p_dropout;\n    params.rp_dropout = 1.f / params.p_dropout;\n    TORCH_CHECK(p_dropout < 1.f);\n    set_alpha(params.scale_dropout, params.rp_dropout, data_type);\n}\n\nstd::vector<at::Tensor>\nmha_fwd(const at::Tensor &qkv,  // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n        const at::Tensor &cu_seqlens,  // b+1\n        const float p_dropout,\n        const int max_seq_len,\n        const bool is_training,\n        c10::optional<at::Generator> gen_) {\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);\n    int seq_len = 512;\n    auto launch = &run_fmha_fp16_512_64_sm80;\n    if( max_seq_len <= 128 ) {\n        seq_len = 128;\n        launch = &run_fmha_fp16_128_64_sm80;\n    } else if( max_seq_len <= 256 ) {\n        seq_len = 256;\n        launch = &run_fmha_fp16_256_64_sm80;\n    } else if( max_seq_len <= 384 ) {\n        seq_len = 384;\n        launch = &run_fmha_fp16_384_64_sm80;\n    } else if( max_seq_len <= 512 ) {\n        seq_len = 512;\n        launch = &run_fmha_fp16_512_64_sm80;\n    } else {\n        TORCH_CHECK(false);\n    }\n\n    constexpr int warps_m = 1;\n    constexpr int warps_n = 4;  // this leads to an upper bound\n    const int mmas_m = seq_len / 16 / warps_m;\n    const int mmas_n = seq_len / 16 / warps_n;\n    \n    const int elts_per_thread = 8 * mmas_m * mmas_n;\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.dtype() == torch::kFloat16);\n    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    TORCH_CHECK(qkv.is_contiguous())\n    
TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n    auto opts = qkv.options();\n\n    auto ctx = torch::empty({ total, num_heads, head_size }, opts);\n\n    auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               ctx.data_ptr(),\n               s.data_ptr(),\n               p_dropout);\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    int64_t counter_offset = elts_per_thread;\n    at::PhiloxCudaState rng_engine_inputs;\n\n    if( is_training ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n\n    launch(params, is_training, stream);\n\n    return { ctx, s };\n}\n\nstd::vector<at::Tensor>\nmha_bwd(const at::Tensor &dout,  // total x num_heads x head_size\n        const at::Tensor &qkv,   // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n        at::Tensor &softmax,     // b x h x s x s softmax and dmask - will be overwritten with dP\n        const at::Tensor &cu_seqlens,  // b+1\n        const float p_dropout,         // 
probability to drop\n        const int max_seq_len          // max sequence length to choose the kernel\n) {\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);\n    int seq_len = 512;\n    auto launch = &run_fmha_dgrad_fp16_512_64_sm80;\n    if( max_seq_len <= 128 ) {\n        seq_len = 128;\n        launch = &run_fmha_dgrad_fp16_128_64_sm80;\n    } else if( max_seq_len <= 256 ) {\n        seq_len = 256;\n        launch = &run_fmha_dgrad_fp16_256_64_sm80;\n    } else if( max_seq_len <= 384 ) {\n        seq_len = 384;\n        launch = &run_fmha_dgrad_fp16_384_64_sm80;\n    } else if( max_seq_len <= 512 ) {\n        seq_len = 512;\n        launch = &run_fmha_dgrad_fp16_512_64_sm80;\n    } else {\n        TORCH_CHECK(false);\n    }\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.dtype() == torch::kFloat16);\n    TORCH_CHECK(dout.dtype() == torch::kFloat16);\n    TORCH_CHECK(softmax.dtype() == torch::kFloat16);\n    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);\n\n    TORCH_CHECK(qkv.is_cuda());\n    TORCH_CHECK(cu_seqlens.is_cuda());\n\n    TORCH_CHECK(qkv.is_contiguous());\n    TORCH_CHECK(cu_seqlens.is_contiguous());\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n\n    auto dqkv = torch::empty_like(qkv);\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               dout.data_ptr(),     // we set o_ptr to dout\n               
softmax.data_ptr(),  // softmax gets overwritten by dP!\n               p_dropout);\n\n    // we're re-using these scales\n    Data_type acc_type = DATA_TYPE_FP32;\n    set_alpha(params.scale_bmm1, 1.f, acc_type);\n    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n    params.dqkv_ptr = dqkv.data_ptr();\n\n    launch(params, stream);\n    return { dqkv, softmax };\n}\n\nstd::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n                                const at::Tensor &cu_seqlens,  // b+1\n                                const float p_dropout,\n                                const int max_seq_len,\n                                const bool is_training,\n                                c10::optional<at::Generator> gen_) {\n    int seq_len = 512;\n    auto launch = &run_fmha_fp16_512_64_sm80_nl;\n    TORCH_CHECK(max_seq_len == seq_len);\n\n    constexpr int warps_m = 1;\n    constexpr int warps_n = 4;  // this leads to an upper bound\n    const int mmas_m = seq_len / 16 / warps_m;\n    const int mmas_n = seq_len / 16 / warps_n;\n    // static_assert( mmas_m == 32 );\n    // static_assert( mmas_n == 4 );\n    const int elts_per_thread = 8 * mmas_m * mmas_n;\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    TORCH_CHECK(qkv.is_contiguous())\n    TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n    auto opts = qkv.options();\n\n   
 auto ctx = torch::empty({ total, num_heads, head_size }, opts);\n\n    auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               ctx.data_ptr(),\n               s.data_ptr(),\n               p_dropout);\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    int64_t counter_offset = elts_per_thread;\n    at::PhiloxCudaState rng_engine_inputs;\n\n    if( is_training ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n    int num_chunks = 3;\n    if(batch_size == 3) {\n        num_chunks = 2;\n    }\n\n    launch(params, is_training, num_chunks, stream);\n\n    return { ctx, s };\n}\n\nstd::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout,        // total x num_heads x head_size\n                                const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n                                at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP\n                                const at::Tensor &cu_seqlens,  // b+1\n                                const float p_dropout,         // probability to drop\n                                const int max_seq_len          // max sequence length to choose the kernel\n) {\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    TORCH_CHECK(qkv.is_cuda())\n    TORCH_CHECK(cu_seqlens.is_cuda())\n\n    
TORCH_CHECK(qkv.is_contiguous())\n    TORCH_CHECK(cu_seqlens.is_contiguous())\n\n    TORCH_CHECK(cu_seqlens.dim() == 1);\n\n    TORCH_CHECK(qkv.dim() == 4);\n\n    const auto sizes = qkv.sizes();\n\n    TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n    const int batch_size = cu_seqlens.numel() - 1;\n    \n    const int total = sizes[TOTAL_DIM];\n    const int num_heads = sizes[H_DIM];\n    const int head_size = sizes[D_DIM];\n    TORCH_CHECK(batch_size > 0);\n    TORCH_CHECK(head_size == 64);\n\n    int seq_len = 512;\n    auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;\n\n    auto opts = qkv.options();\n\n    auto dqkv = torch::empty_like(qkv);\n\n    int num_chunks = 2;\n    if( batch_size == 1 ) {\n        num_chunks = 4;\n    }else if( batch_size == 2 ) {\n        num_chunks = 3;\n    }\n    auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);\n\n    Fused_multihead_attention_fprop_params params;\n\n    set_params(params,\n               batch_size,\n               seq_len,\n               num_heads,\n               head_size,\n               qkv.data_ptr(),\n               cu_seqlens.data_ptr(),\n               dout.data_ptr(),     // o_ptr = dout\n               softmax.data_ptr(),  // softmax gets overwritten by dP!\n               p_dropout);\n\n    params.dkv_ptr = dkv.data_ptr();\n\n    Data_type acc_type = DATA_TYPE_FP32;\n    set_alpha(params.scale_bmm1, 1.f, acc_type);\n    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n    params.dqkv_ptr = dqkv.data_ptr();\n\n    launch(params, num_chunks, stream);\n\n    //SPLIT-K reduction of num_chunks dK, dV parts\n\n    // The equivalent of the following Pytorch code:\n    // using namespace torch::indexing;\n    // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});\n    // torch::sum_out(view_out, dkv, 1);\n\n    const int hidden_size = num_heads * head_size;\n    fmha_run_noloop_reduce(\n        
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);\n\n    return { dqkv, softmax, dkv };\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"Fused Multi-head Self-attention for BERT\";  \n    m.def(\"fwd\", &mha_fwd, \"Forward pass\");\n    m.def(\"bwd\", &mha_bwd, \"Backward pass\");\n    m.def(\"fwd_nl\", &mha_fwd_nl, \"Forward pass (small-batch)\");\n    m.def(\"bwd_nl\", &mha_bwd_nl, \"Backward pass (small-batch)\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/gemm.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/utils.h>\n\n#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >\nstruct Fragment_base_ {\n\n    // The data type.\n    using Data_type = Data_type_;\n    // default input type\n    using Input_type_ = Data_type_;\n    // Does it store the array of elements.\n    enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };\n    // The number of elements.\n    enum { NUM_ELTS = NUM_ELTS_ };\n    // The size of an element in bits.\n    enum { BITS_PER_ELT = BITS_PER_ELT_ };\n    // The size in bytes of a single register.\n    enum { BYTES_PER_REG = 4 };\n    // The size in bits.\n    enum { BITS_PER_REG = BYTES_PER_REG * 8 };\n    // The number of registers needed to store the fragment.\n    enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };\n    // The size in bytes (as returned by sizeof(Fragment_base<>)).\n    enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };\n    // The alignment.\n    enum { ALIGNMENT = ALIGNMENT_ > 0 ? 
ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The type of the elements.\n    typename Data_type_,\n    // The number of elements.\n    int NUM_ELTS_,\n    // The alignment if you want to force a value -- use 0 otherwise.\n    int ALIGNMENT_ = 0,\n    // The base class.\n    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>\n>\nstruct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {\n\n    // The size of a load/store.\n    enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };\n\n    // Clear the fragment. Using PTX in that code seems to produce better SASS...\n    inline __device__ void clear() {\n        #pragma unroll\n        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {\n            asm volatile(\"mov.u32 %0, 0; \\n\" : \"=r\"(this->reg(ii)) : );\n        }\n    }\n\n    // Immutable access to a register.\n    inline __device__ const uint32_t& reg(int ii) const {\n        return this->regs_[ii];\n    }\n\n    // Mutable access to a register.\n    inline __device__ uint32_t& reg(int ii) {\n        return this->regs_[ii];\n    }\n\n    uint32_t regs_[Base_::NUM_REGS];\n\n    // Immutable access to the elements.\n    inline __device__ const Data_type_& elt(int ii) const {\n        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];\n    }\n\n    // Mutable access to the elements.\n    inline __device__ Data_type_& elt(int ii) {\n        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];\n    }\n\n    // Immutable access to the elements with a cast.\n    template< typename Cast_type >\n    inline __device__ const Cast_type& elt_as(int ii) const {\n        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];\n    }\n\n    // Mutable access to the elements.\n    template< typename Cast_type >\n    inline __device__ 
Cast_type& elt_as(int ii) {\n        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];\n    }\n\n    // Add another fragment.\n    inline __device__ void add(const Fragment &other) {\n        #pragma unroll\n        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {\n            this->elt(ii) += other.elt(ii);\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Layout >\nstruct Fragment_a : public Fragment<uint16_t, 8> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Layout >\nstruct Fragment_b : public Fragment<uint16_t, 8> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fragment_accumulator : public Fragment<float, 8> {\n\n    // The base class.\n    using Base = Fragment<float, 8>;\n\n    // Add two fragments.\n    template< typename Other_fragment_ >\n    inline __device__ void add(const Other_fragment_ &other) {\n        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {\n            this->elt(ii) = this->elt(ii) + other.elt(ii);\n        }\n    }\n\n    // Do the HMMA.\n    template< typename Layout_a, typename Layout_b >\n    inline __device__ void mma(const Fragment_a<Layout_a> &a,\n                               const Fragment_b<Layout_b> &b) {\n        asm volatile( \\\n            \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\" \\\n            \"    {%0, %1, %2, %3}, \\n\" \\\n            \"    {%4, %5, %6, %7}, \\n\" \\\n            \"    {%8, %9}, \\n\" \\\n            \"    {%0, %1, %2, %3}; \\n\" \\\n                    : \"+f\"(  elt(0)), \"+f\"(  elt(1)), \"+f\"(  elt(2)), \"+f\"(  elt(3))\n                    :  \"r\"(a.reg(0)),  \"r\"(a.reg(1)),  \"r\"(a.reg(2)),  \"r\"(a.reg(3))\n                    ,  \"r\"(b.reg(0)),  \"r\"(b.reg(1)));\n        asm volatile( \\\n            
\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\" \\\n            \"    {%0, %1, %2, %3}, \\n\" \\\n            \"    {%4, %5, %6, %7}, \\n\" \\\n            \"    {%8, %9}, \\n\" \\\n            \"    {%0, %1, %2, %3}; \\n\" \\\n                    : \"+f\"(  elt(4)), \"+f\"(  elt(5)), \"+f\"(  elt(6)), \"+f\"(  elt(7))\n                    :  \"r\"(a.reg(0)),  \"r\"(a.reg(1)),  \"r\"(a.reg(2)),  \"r\"(a.reg(3))\n                    ,  \"r\"(b.reg(2)),  \"r\"(b.reg(3)));\n    }\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Fragment, int M, int N >\ninline __device__ void clear(Fragment (&frag)[M][N]) {\n    #pragma unroll\n    for( int mi = 0; mi < M; ++mi ) {\n        #pragma unroll\n        for( int ni = 0; ni < N; ++ni ) {\n            frag[mi][ni].clear();\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Accumulator_type, int WARPS_K >\nstruct Clear_accumulator {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int WARPS_K >\nstruct Clear_accumulator<float, WARPS_K> {\n  template< typename Acc, int M, int N >\n  static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {\n    fmha::clear(acc);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Acc, typename A, typename B, int M, int N>\ninline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {\n\n    #pragma unroll\n    for( int mi = 0; mi < M; ++mi ) {\n        #pragma unroll\n        for( int ni = 0; ni < N; ++ni ) {\n            acc[mi][ni].mma(a[mi], b[ni]);\n        }\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The number of rows in the CTA tile.\n    int M_,\n    // The number of cols in the CTA tile.\n    int N_,\n    // The number of elements in the K dimension of the GEMM loop.\n    int K_,\n    // The number of rows of warps.\n    int WARPS_M_,\n    // The number of cols of warps.\n    int WARPS_N_,\n    // The number of warps in the K dimension of the GEMM loop.\n    int WARPS_K_>\nstruct Cta_tile_ {\n\n    enum { M = M_, N = N_, K = K_ };\n    // The number of warps.\n    enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };\n    // The number of warps per CTA.\n    enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };\n    // The number of threads per warp.\n    enum { THREADS_PER_WARP = 32 };\n    // The number of threads per CTA.\n    enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Hmma_tile {\n    // The number of elements computed with a single warp-MMA.\n    enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };\n\n    // The number of elements computed with a single CTA-MMA.\n    enum {\n        M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,\n        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,\n        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K\n    };\n\n    // The number of MMAs needed to compute the GEMM.\n    enum {\n        MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,\n        MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,\n        MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,\n    };\n\n    // The number of elements computed per warp.\n    enum {\n        M_PER_WARP = MMAS_M * M_PER_MMA,\n        N_PER_WARP = MMAS_N * N_PER_MMA,\n        K_PER_WARP = MMAS_K * K_PER_MMA,\n    
};\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing A_type = uint16_t;\nusing B_type = uint16_t;\nusing C_type = uint16_t;\nusing Accumulator_type = float;\nusing Epilogue_type = float;\n\nconstexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;\nconstexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;\nconstexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>\nusing Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile_>\nusing Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,\n                                                   Cta_tile_::N,\n                                                   Next_power_of_two<Cta_tile_::K>::VALUE,\n                                                   Cta_tile_::WARPS_M,\n                                                   Cta_tile_::WARPS_N,\n                                                   Cta_tile_::WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The number of bits per element.\n    int BITS_PER_ELEMENT,\n    // The number of rows of Q, K or V loaded by this tile.\n    int ROWS,\n    // The number of columns.\n    int COLS,\n    // The number of matrices.\n    int NUM_MATS = 3\n>\nstruct Gmem_tile_qkv {\n\n    // The size of each LDG.\n    enum { BYTES_PER_LDG = 16 };\n    // The size of a row in bytes.\n    enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };\n\n    // The number of threads to load a \"row\" of the matrix.\n    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };\n\n    // The number of \"rows\" loaded per LDG.\n    enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // The number of LDGs needed to load a chunk of the Q matrix.\n    enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };\n\n    // Ctor.\n    template< typename Params, typename BInfo >\n    inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)\n        : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)\n        , actual_seqlen(binfo.actual_seqlen)\n        , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) 
{\n\n        // Compute the position in the sequence (within the CTA for the moment).\n        int row = tidx / THREADS_PER_ROW;\n        // Compute the position of the thread in the row.\n        int col = tidx % THREADS_PER_ROW;\n\n        // Store the row as we need it to disable the loads.\n        row_ = row;\n\n        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.\n        int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;\n        // Add the block index.\n        row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        qkv_ptr_ += row_offset + col * BYTES_PER_LDG;\n    }\n\n    // Store data to shared memory.\n    template< typename Smem_tile >\n    inline __device__ void commit(Smem_tile &smem_tile) {\n        smem_tile.store(fetch_);\n    }\n\n    // Load data from memory.\n    template< typename Smem_tile >\n    inline __device__ void load(Smem_tile &smem_tile) {\n        const void *ptrs[LDGS];\n        uint32_t preds[LDGS];\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));\n            fetch_[ii] = make_uint4(0, 0, 0, 0);\n        }\n\n        // not packing predicates removes restrictions (e.g. 
FP16 384, 4 warps)\n        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            fct.load(ii, preds[ii]);\n        }\n    }\n\n    // Store data to memory.\n    inline __device__ void store(const uint4 (&data)[LDGS]) {\n        #pragma unroll\n        for( int ii = 0; ii < LDGS; ++ii ) {\n            char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {\n                fmha::stg(ptr, data[ii]);\n            }\n        }\n    }\n\n    // Move the pointer to the next location.\n    inline __device__ void move() {\n        qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;\n        actual_seqlen -= ROWS;\n    }\n\n    // The stride between rows for the QKV matrix.\n    int64_t params_qkv_stride_in_bytes_;\n    // The pointer.\n    char *qkv_ptr_;\n    // The fetch registers.\n    uint4 fetch_[LDGS];\n    // Keep track of the row the thread is processing as we move the tile.\n    int row_;\n    // The length of the sequence loaded by that memory tile.\n    int actual_seqlen;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile >\nstruct Gmem_tile_o {\n\n    // The mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // The size of each element.\n    enum { BYTES_PER_ELEMENT = 2 };\n    // The size of a row in bytes.\n    enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };\n\n    // The number of threads to store a \"row\" of the matrix.\n    enum { THREADS_PER_ROW = 16 };\n    // The size of each STG.\n    enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };\n\n    // The number of \"rows\" stored per iteration of the loop. The output of 1 MMA.\n    enum { ROWS = Cta_tile::M };\n    // The number of \"rows\" stored per iteration of the loop. 
The output of 1 MMA.\n    enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of outer loops for the stores.\n    enum { LOOPS = ROWS / ROWS_PER_LOOP };\n\n    // The number of \"rows\" stored per STG.\n    enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Do we have to guard against partial writes/reads.\n    enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };\n    // The number of STGs needed to store a chunk of the Q matrix.\n    enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };\n    // The number of STGs needed to store a chunk of the Q matrix in total.\n    enum { STGS = STGS_PER_LOOP * LOOPS };\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, int tidx)\n        : params_o_stride_in_bytes_(params.o_stride_in_bytes)\n        , actual_seqlen_(binfo.actual_seqlen)\n        , o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {\n\n        // Compute the position in the sequence (within the CTA for the moment).\n        int row = tidx / THREADS_PER_ROW;\n        // Compute the position of the thread in the row.\n        int col = tidx % THREADS_PER_ROW;\n\n        // Store the row as we need it to disable loads.\n        row_ = row;\n\n        // The row offset in the batched GEMM.\n        int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;\n        // Assemble the final pointer.\n        o_ptr_ += row_offset + col * BYTES_PER_STG;\n\n        // Is that thread active on the last STG?\n        if( HAS_INCOMPLETE_STG ) {\n            is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;\n        }\n    }\n\n    // Store data to global memory.\n    inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {\n\n        #pragma unroll\n        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {\n            int jj = mi * STGS_PER_LOOP + ii;\n            if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {\n                break;\n            }\n\n            float x = reinterpret_cast<const float &>(src[ii].x);\n            float y = reinterpret_cast<const float &>(src[ii].y);\n            float z = reinterpret_cast<const float &>(src[ii].z);\n            float w = reinterpret_cast<const float &>(src[ii].w);\n            uint2 out = float4_to_half4(x, y, z, w);\n            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {\n                fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);\n            }\n        }\n    }\n\n    // Move the pointer to the next location.\n    inline __device__ void move() {\n        row_ += ROWS;\n        o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;\n    }\n\n    // The stride between rows for the O matrix.\n    int64_t params_o_stride_in_bytes_;\n    // The pointer.\n    char *o_ptr_;\n    // Is the thread active for the last STG?\n    int is_active_for_last_stg_;\n    // Keep track of the row to disable loads.\n    int row_;\n    // The length of the sequence loaded by that memory tile.\n    int actual_seqlen_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, int BYTES_PER_ELEMENT >\nstruct Gmem_tile_mma_sd {\n\n    // The mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // Each STG stores 8 elements.\n    enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };\n    // The number of MMAs in the M dimension.\n    enum { MMAS_M = Mma_tile::MMAS_M };\n    // The number of MMAs in the N dimension.\n    enum { MMAS_N = Mma_tile::MMAS_N };\n    // The number of rows computed per MMA per thread block.\n    enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of cols computed per MMA per thread block.\n    enum { N_PER_MMA_PER_CTA = 
Mma_tile::N_PER_MMA_PER_CTA };\n    // The number of threads per block.\n    enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };\n    // The size of each row in bytes. I.e. how many bytes are stored per STG.\n    enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };\n    // The fixed sequence length.\n    enum { SEQLEN = Cta_tile::N };\n    // The distance between two blocks (in bytes).\n    enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };\n    // The distance between elements stored per loop (in bytes).\n    enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };\n\n    // The type of elements stored per STG.\n    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx) \n        : ptr_(static_cast<char *>(ptr)) {\n\n        // The block index for the batch.\n        const int bidb = blockIdx.y;\n        // The block index for the head.\n        const int bidh = blockIdx.x;\n        // The block index.\n        size_t bidx = bidb * params.h + bidh;\n\n        // Set store location for each thread at the beginning of the loop\n        ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;\n    }\n\n    // Store to global memory.\n    inline __device__ void store(const Type &data, const int mi, const int ni) {\n        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n        fmha::stg(ptr_ + offset, data);\n    }\n\n    // Load from global memory.\n    inline __device__ void load(Type &data, const int mi, const int ni) {\n        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n        fmha::ldg(data, ptr_ + offset);\n    }\n\n    // Move to the next tile.\n    inline __device__ void move() {\n        ptr_ += LOOP_STRIDE_BYTES;\n    }\n\n    // The pointer in global memory.\n    char 
*ptr_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >\nstruct Gmem_tile_mma_s : public Base {\n\n    // The number of mmas in the vertical dimension.\n    enum { M = Base::MMAS_M };\n    // The number of mmas in the horizontal dimension.\n    enum { N = Base::MMAS_N };\n    // The type of the vectors stored by each STG.\n    using Type = typename Base::Type;\n\n    // Ctor.\n    template< typename Params >\n    inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx) \n        : Base(ptr, params, tidx) {\n    }\n\n    // Store to global memory.\n    template<typename Mask>\n    inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n\n                float tmp00 = softmax[2 * mi + 0][4 * ni + 0];\n                float tmp01 = softmax[2 * mi + 0][4 * ni + 1];\n                float tmp02 = softmax[2 * mi + 0][4 * ni + 2];\n                float tmp03 = softmax[2 * mi + 0][4 * ni + 3];\n\n                float tmp10 = softmax[2 * mi + 1][4 * ni + 0];\n                float tmp11 = softmax[2 * mi + 1][4 * ni + 1];\n                float tmp12 = softmax[2 * mi + 1][4 * ni + 2];\n                float tmp13 = softmax[2 * mi + 1][4 * ni + 3];\n\n                uint4 dst;\n                dst.x = fmha::float2_to_half2(tmp00, tmp01);\n                dst.y = fmha::float2_to_half2(tmp02, tmp03);\n                dst.z = fmha::float2_to_half2(tmp10, tmp11);\n                dst.w = fmha::float2_to_half2(tmp12, tmp13);\n                if( mask.is_valid(mi, ni, 0, 0) ) {\n                    Base::store(dst, mi, ni);\n                }\n            }\n        }\n    }\n\n    // Load from global memory.\n    template<typename 
Mask>\n    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                regs[mi][ni] = make_uint4(0, 0, 0, 0);\n                if( mask.is_valid(mi, ni, 0, 0) ) {\n                    Base::load(regs[mi][ni], mi, ni);\n                }\n            }\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The base class.\n    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>\n>\nstruct Gmem_tile_dout : public Base {\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)\n        : Base(params, 0, binfo, tidx) {\n\n        this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);\n        this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes;  // needed for move\n\n        // Compute the position of the thread in the row.\n        int col = tidx % Base::THREADS_PER_ROW;\n\n        // The row offset in the batched GEMM. 
For each seq element, we store O in that order.\n        int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >\nstruct Gmem_tile_dq : public Base {\n\n    // Ctor.\n    template<typename Params, typename BInfo>\n    inline __device__ Gmem_tile_dq(const Params &params, const BInfo &binfo, int tidx) \n        : Base(params, binfo, tidx) {\n        this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);\n        this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes;  // needed for move\n\n        // Compute the position of the thread in the row.\n        int col = tidx % Base::THREADS_PER_ROW;\n\n        // The row offset in the batched GEMM. For each seq element, we store O in that order.\n        int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +\n                             (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;\n\n        // Assemble the final pointer.\n        this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>\nstruct FMHA_kernel_traits {\n\n    // The CTA description for the 1st GEMM.\n    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;\n    // The CTA description for the 2nd GEMM.\n    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;\n\n    // Do we use one buffer for K and V.\n    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;\n\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;\n\n    // The global memory tile to 
store O.\n    using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;\n    // The shared memory tile for O.\n    using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;\n\n    // The global memory tile to load/store S.\n    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;\n\n    // The shared memory tile to transpose S.\n    using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;\n\n    using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;\n\n    // Make sure the number of threads matches.\n    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, \"\");\n\n    // The number of threads.\n    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };\n    // Make sure the number of threads matches both CTAs.\n    static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, \"\");\n\n    // The amount of shared memory needed to load Q and K.\n    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };\n    // The extra amount of shared memory needed to load V.\n    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };\n    // The amount of shared memory needed for Q, K and V.\n    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };\n    // The amount of shared memory needed to load Q and store O.\n    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };\n\n    // The amount of shared memory needed for Q, K, V and O.\n    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };\n    // Make sure we have enough shared memory.\n    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, \"\");\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/mask.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n\ntemplate<typename Cta_tile>\nstruct Mask {\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    template<typename Params, typename BInfo>\n    __device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {\n\n        actual_seqlen = blockInfo.actual_seqlen;\n\n        const int warp = tidx / Cta_tile::THREADS_PER_WARP;\n        const int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n        static_assert(Cta_tile::WARPS_K == 1, \"\");\n\n        // find the warp in the Cta tile\n        const int warp_n = (warp / Cta_tile::WARPS_M);\n        const int warp_m = (warp % Cta_tile::WARPS_M);\n        // decompose warp into 8x4 tile\n        const int quad = lane / 4;\n        const int tid = (lane % 4) * 2;\n        row = warp_m * 16 + quad;\n        col = warp_n * 16 + tid;\n    }\n\n    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {\n\n        // ii and jj iterate over the 2x4 fragment\n        const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;\n        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;\n        return col_valid;\n        // return row_valid && col_valid;\n    }\n\n    inline __device__ void load(int it) {\n        row_offset = it * Cta_tile::M + row;\n    }\n    int 
row_offset;\n\n    int row;\n    int col;\n    int actual_seqlen;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/smem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/utils.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The description of the tile computed by this CTA.\n    typename Cta_tile, \n    // The number of rows in the 2D shared memory buffer.\n    int M_, \n    // The number of cols.\n    int N_, \n    // The size in bits of each element.\n    int BITS_PER_ELEMENT_, \n    // The number of bytes per STS.\n    int BYTES_PER_STS_ = 16,\n    // The number of buffers. (Used in multistage and double buffer cases.)\n    int BUFFERS_PER_TILE_ = 1,\n    // Do we enable the fast path for LDS.128 and friends.\n    int ENABLE_LDS_FAST_PATH_ = 0, \n    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. \n    int ROWS_PER_XOR_PATTERN_ = 8,\n    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. 
\n    int COLS_PER_XOR_PATTERN_ = 1,\n    // Use or not predicates\n    bool USE_PREDICATES_ = true\n>\nstruct Smem_tile_without_skews {\n\n    // The size in bits of each element.\n    enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };\n    // The size in bytes of a single STS.\n    enum { BYTES_PER_STS = BYTES_PER_STS_ };\n    // The number of elements per STS.\n    enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };\n    // To support arbitrary N, we pad some values to a power-of-2.\n    enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE }; \n    // The number of bytes per row without packing of rows.\n    enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };\n    // The number of bytes per row -- we want at least 128B per row.\n    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };\n    // The number of rows in shared memory (two rows may be packed into a single one).\n    enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };\n\n    // The number of threads per row.\n    enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };\n    // The number of threads per row.\n    enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };\n\n    // The number of STS per row.\n    enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };\n    // It must be at least one.\n    static_assert(STS_PER_ROW >= 1, \"\");\n    // The number of rows written with a single STS.\n    enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Make sure we write to at least one row per STS. Thanks Dr. 
Obvious ;)\n    static_assert(ROWS_PER_STS >= 1, \"\");\n    // The number of STS needed to store all rows.\n    enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };\n    // The number of STS in total.\n    enum { STS = STS_PER_COL * STS_PER_ROW };\n\n    // The size of one buffer in bytes in shared memory.\n    enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };\n    // The number of buffers. \n    enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };\n    // The size in bytes of total buffers.\n    enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };\n    // The boundary for smem_read_offset and smem_write_offset increment.\n    enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };\n\n    // Do we enable the LDS.128 fast path?\n    enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };\n    static_assert(ENABLE_LDS_FAST_PATH == 0);\n    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. \n    enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };\n    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. \n    enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };\n    // Use or not predicates\n    enum { USE_PREDICATES = USE_PREDICATES_ };\n\n    // The type of elements that are stored in shared memory by each thread.\n    using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;\n\n    // Ctor.\n    inline __device__ Smem_tile_without_skews(void *smem, int tidx) \n        : smem_(__nvvm_get_smem_pointer(smem)) {\n\n        // The row written by a thread. 
See doc/mma_smem_layout.xlsx.\n        int smem_write_row = tidx / THREADS_PER_ROW;\n\n        // The XOR pattern.\n        int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;\n        // Compute the column and apply the XOR pattern.\n        int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;\n\n        // The offset.\n        this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;\n\n        // TODO: Why not merge it with the read offset?\n        this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n        this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n    }\n\n    // Compute the store pointers.\n    template< int N >\n    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {\n        #pragma unroll\n        for( int ii = 0; ii < N; ++ii ) {\n            // Decompose the STS into row/col.\n            int row = ii / STS_PER_ROW;\n            int col = ii % STS_PER_ROW;\n\n            // Assemble the offset.\n            int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;\n\n            // Take the column into account.\n            if( STS_PER_ROW > 1 ) {\n                offset += col*THREADS_PER_ROW*BYTES_PER_STS; \n            }\n\n            // Apply the XOR pattern if needed.\n            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {\n                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;\n                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;\n            }\n\n            // Assemble the final pointer :)\n            ptrs[ii] = smem_ + offset + smem_write_buffer_;\n        }\n    }\n\n    inline __device__ void debug_reset() {\n        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n        for( int row = 0; row < ROWS; ++row ) {\n            for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {\n                if( threadIdx.x == 0 ) {\n                    uint32_t val = 0x0;\n     
               sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);\n                }\n            }\n        }\n        }\n    }\n\n    // Print the content of the tile (only for debug ;)).\n    inline __device__ void debug_print() const {\n        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n        for( int row = 0; row < ROWS; ++row ) {\n            for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {\n                if( threadIdx.x == 0 ) {\n                    uint32_t val;\n                    lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);\n                    printf(\"block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\\n\",\n                        blockIdx.x,\n                        blockIdx.y,\n                        blockIdx.z,\n                        smem_,\n                        buffer,\n                        row,\n                        col,\n                        val);\n                }\n            }\n        }\n        }\n    }\n\n    // Move the read offset to next buffer.\n    inline __device__ void move_to_next_read_buffer() {\n        if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {\n            this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n        } else if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_read_buffer_ += BYTES_PER_BUFFER;\n        }\n    }\n\n    // Move the read offset to next buffer. TODO: Remove this member function!!!\n    inline __device__ void move_next_read_buffer() {\n        this->move_to_next_read_buffer();\n    }\n\n    // Move the read offset to next N buffer (circular-buffer).\n    inline __device__ void move_to_next_read_buffer(int N) {\n        if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_read_buffer_ += N * BYTES_PER_BUFFER;\n            this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? 
BYTES_PER_TILE : 0;\n        }\n    }\n\n    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!\n    inline __device__ void move_next_read_buffer(int N) {\n        this->move_to_next_read_buffer(N);\n    }\n\n    // Move the write offset to next buffer.\n    inline __device__ void move_to_next_write_buffer() {\n        if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {\n            this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n        } else if( BUFFERS_PER_TILE > 1 ) {\n            this->smem_write_buffer_ += BYTES_PER_BUFFER;\n        }\n    }\n\n    // Move the write offset to next buffer. TODO: Remove that member function!\n    inline __device__ void move_next_write_buffer() {\n        this->move_to_next_write_buffer();\n    }\n\n    // Move the read offset.\n    inline __device__ void move_read_offset(int delta) {\n        this->smem_read_offset_ += delta;\n    }\n\n    // Move the write offset.\n    inline __device__ void move_write_offset(int delta) {\n        this->smem_write_offset_ += delta;\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {\n        uint32_t smem_ptrs[N];\n        this->compute_store_pointers(smem_ptrs);\n        sts(smem_ptrs, data);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N, int M >\n    inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {\n        uint32_t smem_ptrs[N];\n        this->compute_store_pointers(smem_ptrs);\n        sts(smem_ptrs, data, preds);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { \n        this->store(data, preds);\n    }\n\n    // Store to the tile in shared memory.\n    template< int N >\n    inline __device__ void 
store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {\n        uint32_t tmp[1] = { preds };\n        this->store(gmem_ptrs, tmp);\n    }\n\n    // The shared memory pointer.\n    uint32_t smem_;\n    // The read offset. Reserve 4 offsets if needed.\n    int smem_read_offset_;\n    // The write offset.\n    int smem_write_offset_;\n    // The buffer base offset for read.\n    int smem_read_buffer_;\n    // The buffer base offset for write.\n    int smem_write_buffer_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile, \n    // The layout of the tile.\n    typename Layout, \n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true\n>\nstruct Smem_tile_a {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K, int MMAS_K_WITH_PADDING >\nstruct Compute_reset_mask {\n    // The potential mask.\n    enum { HALF = MMAS_K_WITH_PADDING / 2 };\n    // The remainder.\n    enum { MOD = MMAS_K % HALF };\n    // The final value.\n    enum { VALUE = (MMAS_K == MOD ? 
0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K_WITH_PADDING >\nstruct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {\n    enum { VALUE = 0 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int MMAS_K >\nstruct Compute_reset_mask<MMAS_K, MMAS_K> {\n    enum { VALUE = MMAS_K - 1 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_a {\n    // The size in bits.\n    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };\n    // The number of rows.\n    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE\n>\nstruct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,\n                                                               Cta_tile::M,\n                                                               Cta_tile::K,\n                                                               fmha::BITS_PER_ELEMENT_A,\n                                                               BYTES_PER_STS,\n                                                               BUFFERS_PER_TILE,\n             
                                                  0,\n                                                               ROWS_PER_XOR_PATTERN_,\n                                                               1> {\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::M,\n                                         Cta_tile::K,\n                                         fmha::BITS_PER_ELEMENT_A,\n                                         BYTES_PER_STS,\n                                         BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         1>;\n    // The fragment.\n    using Fragment = Fragment_a<Row>;\n\n    // When we use padding to reach a power of two, special care has to be taken.\n    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;\n    // The number of MMAs.\n    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // Ctor.\n    inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {\n\n        // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n\n        static_assert(WARPS_M == 1);\n        static_assert(WARPS_N == 4 || WARPS_N == 8);\n        static_assert(WARPS_K == 1);\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n        // The row and column read by the thread.\n        int smem_read_row  = (tidx & 0x0f);\n        int smem_read_col  = (tidx & 0x07);\n        smem_read_col ^= (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // Undo the pointer increment for the next ni.\n        // Should match the load function below for ki = 0.\n        if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {\n        #pragma unroll\n        for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {\n            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n            int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n            // Load using LDSM.M88.4.\n            uint4 tmp;\n            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n            // Store the value into the fragment.\n            a[mi].reg(0) = tmp.x;\n            a[mi].reg(1) = tmp.y;\n            a[mi].reg(2) = tmp.z;\n            a[mi].reg(3) = tmp.w;\n        }\n\n        // Move the offset to the next position. 
See doc/mma_smem_layout.xlsx.\n        static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {\n            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {\n            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {\n            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {\n            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Reset the read offset.\n    inline __device__ void reset_read_offset() {\n        // The number of MMAs in the K dimension.\n        enum { MMAS_K = Mma_tile::MMAS_K };\n        // The number of MMAs in the K dimension when we include padding.\n        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n        // Assemble the mask.\n        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n        // Reset the read offset.\n        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n    }\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_a<Cta_tile,\n                                    BYTES_PER_STS,\n                                    BUFFERS_PER_TILE> {\n    // The base class.\n    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, 
BUFFERS_PER_TILE>;\n\n    // Ctor.\n    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< \n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile, \n    // The layout of the tile.\n    typename Layout, \n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true\n>\nstruct Smem_tile_b {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_b {\n    // The size in bits.\n    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };\n    // The number of rows.\n    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nstruct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE\n>\nstruct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,\n                                                           Cta_tile::N,\n                                                           Cta_tile::K,\n                                                           fmha::BITS_PER_ELEMENT_B,\n                                                           BYTES_PER_STS,\n       
                                                    BUFFERS_PER_TILE,\n                                                           0,\n                                                           ROWS_PER_XOR_PATTERN_,\n                                                           1> {\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::N,\n                                         Cta_tile::K,\n                                         fmha::BITS_PER_ELEMENT_B,\n                                         BYTES_PER_STS,\n                                         BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         1>;\n    // The fragment.\n    using Fragment = Fragment_b< Col>;\n\n    // When we use padding to reach a power of two, special care has to be taken.\n    using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;\n    // The number of MMAs.\n    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // The number of STS per thread\n    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n    // The number of STS per thread must be at least 1.\n    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n    // Ctor.\n    inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {\n\n        // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n        static_assert(WARPS_M == 1);\n        
static_assert(WARPS_N == 4 || WARPS_N == 8);\n        static_assert(WARPS_K == 1);\n\n        // The masks to select the warps.\n        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n\n        // The divisor for the warps.\n        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;\n\n        // The row and column read by the thread.\n        int smem_read_row  = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +\n                             (tidx & 0x07) +\n                             (tidx & 0x10) / 2;\n        int smem_read_col  = (tidx & 0x07);\n        smem_read_col ^= (tidx & 0x08) / 8;\n        // The shared memory offset.\n        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // Undo the pointer increment for the next ni.\n        // Should match the load function below for ki = 0.\n        if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n            int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n            // Load using LDSM.M88.4.\n            uint4 tmp;\n            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n            // Store the value into the fragment.\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n        }\n\n        // Move the offset to the next position. 
See doc/mma_smem_layout.xlsx.\n        static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {\n            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {\n            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {\n            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {\n            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;\n        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {\n            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;\n        }\n    }\n\n    // Reset the read offset.\n    inline __device__ void reset_read_offset() {\n        // The number of MMAs in the K dimension.\n        enum { MMAS_K = Mma_tile::MMAS_K };\n        // The number of MMAs in the K dimension when we include padding.\n        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n        // Assemble the mask.\n        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n        // Reset the read offset.\n        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >\n    : public Smem_tile_col_b<Cta_tile,\n                             BYTES_PER_STS,\n                             BUFFERS_PER_TILE> {\n\n    // The base class.\n    using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n    
// Ctor.\n    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<  int N >\nstruct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,\n    // How many cols to use for the XOR pattern to avoid bank conflicts?\n    int COLS_PER_XOR_PATTERN_ = 1\n>\nstruct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,\n                                                               Cta_tile::K,\n                                                               Cta_tile::N,\n                                                               fmha::BITS_PER_ELEMENT_B,\n                                                               BYTES_PER_STS,\n                                                               BUFFERS_PER_TILE,\n                                                               0,\n                                                               ROWS_PER_XOR_PATTERN_,\n                                                               COLS_PER_XOR_PATTERN_> {\n\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile,\n                                         Cta_tile::K,\n                                         Cta_tile::N,\n                                         fmha::BITS_PER_ELEMENT_B,\n                                         BYTES_PER_STS,\n             
                            BUFFERS_PER_TILE,\n                                         0,\n                                         ROWS_PER_XOR_PATTERN_,\n                                         COLS_PER_XOR_PATTERN_>;\n    // The fragment.\n    using Fragment = Fragment_b<Row>;\n\n    // Can we use LDSM? Not if the data type is 32 bits wide.\n    enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };\n    // The number of elements per LDS.\n    enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };\n\n    // The number of STS per thread\n    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n    // The number of STS per thread must be at least 1.\n    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n    // Ctor.\n    inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {\n\n        // The number of warps.\n        const int WARPS_M = Cta_tile::WARPS_M;\n        const int WARPS_N = Cta_tile::WARPS_N;\n        const int WARPS_K = Cta_tile::WARPS_K;\n        static_assert(WARPS_K == 1);\n        static_assert(WARPS_M == 4 || WARPS_M == 8);\n        static_assert(WARPS_N == 1);\n\n        // The masks to select the warps.\n        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n        const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;\n\n        // The divisor for the warps.\n        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;\n        const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;\n\n        // The row/col read by the thread.\n        int smem_read_row, smem_read_col;\n\n        static_assert(USE_LDSMT);\n        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n        smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +\n                        (tidx & 0x07) + (tidx & 0x08);\n        smem_read_col = (tidx & 0x07);\n        smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;\n    }\n\n    // Rewind smem_read_offset for last LDS phase in main loop.\n    inline __device__ void reverse_smem_read_offset(int ki = 0) {\n        // The size of each element in bits.\n        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n        // The size in bytes of the data needed to compute an MMA per CTA.\n        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Undo the pointer increment for the next ni.\n            // Should match the load function below for ki = 0.\n            if( BYTES_PER_MMA_PER_CTA >= 128 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {\n                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n            }\n        }\n\n        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&\n                Mma_tile::MMAS_N % 2 == 1 ) {\n            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n        }\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n        // The size of each element in bits.\n        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n        // The size in bytes of the data needed to compute an MMA per CTA.\n        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Prepare the offset.\n            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;\n            if( BYTES_PER_MMA_PER_CTA == 32 ) {\n                offset += this->smem_read_offset_;\n            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {\n                offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;\n            } else {\n                offset += this->smem_read_offset_ + (ni  ) * BYTES_PER_MMA_PER_CTA;\n            }\n\n            // Load the data using LDSM.MT88.2.\n            uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;\n            uint4 tmp;\n            if( USE_LDSMT ) {\n                ldsmt(tmp, ptr);\n            } else {\n                lds(tmp.x, (ptr     ) + 0*Base::BYTES_PER_ROW);\n                lds(tmp.y, (ptr     ) + 4*Base::BYTES_PER_ROW);\n                lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW);\n                lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW);\n            }\n\n            // Store those values in the 
fragment.\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n\n            // Move the pointer for the next ni. I expect the compiler to not recompute those.\n            if( BYTES_PER_MMA_PER_CTA >= 128 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {\n                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {\n                // Nothing to do!\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);\n            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n            }\n        }\n\n        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&\n                Mma_tile::MMAS_N % 2 == 1 ) {\n            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE\n>\nstruct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_b<Cta_tile,\n                             BYTES_PER_STS,\n                             BUFFERS_PER_TILE> {\n\n    // The base class.\n    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n    // Ctor.\n    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {\n\n    // The base class.\n    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The fragment.\n    using Fragment = Fragment_b< fmha::Col>;\n\n    // The size of a single LDS in bytes.\n    enum { BYTES_PER_LDS = 16 };\n\n    // Ctor.\n    inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {\n\n        // The row/col read by the thread.\n        int read_row, read_col;\n\n        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n        read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);\n        read_col = (tidx & 0x07);\n        read_col ^= (tidx & 0x10) / 16;\n\n        // The shared memory offset.\n        this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    // Load from shared memory.\n    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n#pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n            // Jump by 16 * #warps row.\n            int row = ki * 16 * Cta_tile::WARPS_K;\n\n            // Load the data using LDSM.MT88.2.\n            uint4 tmp;\n            fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);\n            b[ni].reg(0) = tmp.x;\n            b[ni].reg(1) = tmp.y;\n            b[ni].reg(2) = tmp.z;\n            b[ni].reg(3) = tmp.w;\n\n            // Move the pointer for the next ni. I expect the compiler to not recompute those.\n            if( Mma_tile::MMAS_N == 4 ) {\n                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n            } else {\n                assert(false);  // Not implemented!\n            }\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Smem_tile_o {\n\n    // The MMA tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    // The accumulators.\n    using Accumulator = fmha::Fragment_accumulator;\n    // The accumulators.\n    using Data_type = typename Accumulator::Data_type;\n\n    // The size of each element.\n    enum { BYTES_PER_ELEMENT = sizeof(Data_type) };\n    // The size of each STS.\n    enum { BYTES_PER_STS = 8 };\n    // The size of each row in shared memory.\n    enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };\n\n    // The size of each LDS.\n    enum { BYTES_PER_LDS = 16 };\n    enum { THREADS_PER_ROW = 16 };\n\n    // The number of rows.\n    enum { ROWS = Cta_tile::M };\n    // The number of \"rows\" to process per loop iteration (in the \"epilogue\").\n    enum { ROWS_PER_LOOP = ROWS <= 64 ? 
ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n    // The number of outer loops.\n    enum { LOOPS = ROWS / ROWS_PER_LOOP };\n    // Make sure it matches our expectations.\n    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n    // The number of rows loaded per LDS.\n    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    // Do we have to guard against partial writes/reads.\n    enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };\n    // The total number of LDS per loop.\n    enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };\n\n    // The amount of shared memory.\n    enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };\n\n    // The write pointer.\n    uint32_t smem_write_, smem_read_;\n    // Is the thread active for the last LDS of the series?\n    int is_active_for_last_lds_;\n\n    static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);\n    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n    // Ctor.\n    inline __device__ Smem_tile_o(void *smem, int tidx) {\n\n        // Get a 32-bit value for the shared memory address.\n        uint32_t smem_ = __nvvm_get_smem_pointer(smem);\n\n        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n        int write_row = (tidx & 0x1c) / 4;\n        int write_col = (tidx);\n\n        // Assemble the write pointer.\n        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n\n        // The element read by each thread.\n        int read_row = tidx / THREADS_PER_ROW;\n        int read_col = tidx % THREADS_PER_ROW;\n\n        // Take the XOR pattern into account for the column.\n        read_col ^= 2 * (read_row & 0x7);\n\n        // Assemble the read pointer.\n        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n\n        // Is that thread active on the last LDS?\n        if( 
HAS_INCOMPLETE_LDS ) {\n            this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;\n        }\n    }\n\n    // Load the output fragments.\n    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {\n        #pragma unroll\n        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {\n\n            // Load the elements before the reduction (split-K).\n            uint4 tmp[Cta_tile::WARPS_K];\n            #pragma unroll\n            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {\n                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;\n                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {\n                    fmha::lds(tmp[jj], this->smem_read_ + imm);\n                }\n            }\n\n            // Perform the reduction.\n            out[ii] = tmp[0];\n            #pragma unroll\n            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {\n                out[ii] = fmha::fadd4(out[ii], tmp[jj]);\n            }\n        }\n    }\n    // Store the accumulators.\n    template <int M, int N>\n    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {\n        enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };\n        #pragma unroll\n        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {\n\n            // The number of MMAs that are stored per loop iteration.\n            enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };\n\n            // Store 1st column of the different MMAs.\n            #pragma unroll\n            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {\n                // Precompute the immediates to jump between rows.\n                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n                uint2 tmp0, tmp1;\n                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);\n                tmp0.y = acc[mi * 
MMAS_M_PER_LOOP + mj][ni].reg(1);\n\n                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);\n                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);\n\n                // Store.\n                fmha::sts(this->smem_write_ + row_0, tmp0);\n                fmha::sts(this->smem_write_ + row_1, tmp1);\n            }\n\n            // Swizzle the write pointer using a XOR of 16B.\n            this->smem_write_ ^= 32;\n\n            // Store 2nd column of the different MMAs.\n            #pragma unroll\n            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {\n                // Precompute the immediates to jump between rows.\n                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n\n                uint2 tmp0, tmp1;\n                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);\n                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);\n\n                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);\n                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);\n                // Store.\n                fmha::sts(this->smem_write_ + row_0, tmp0);\n                fmha::sts(this->smem_write_ + row_1, tmp1);\n            }\n\n            // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.\n            this->smem_write_ ^= (ni & 1) ? 
7 * 32 : 3 * 32;\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile>\nstruct Smem_tile_mma {\n\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n    using Fragment = fmha::Fragment_a<fmha::Col>;\n\n    enum { COLS = Cta_tile::N };\n    enum { BYTES_PER_ELT = 2 };\n    enum { BYTES_PER_STS = 4 };\n    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO\n    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };\n\n    enum { WARPS_M = Cta_tile::WARPS_M };\n    enum { WARPS_N = Cta_tile::WARPS_N };\n    enum { WARPS_K = Cta_tile::WARPS_K };\n\n    static_assert(WARPS_K == 1);\n    inline __device__ Smem_tile_mma(char *smem, int tidx) {\n        smem_ = __nvvm_get_smem_pointer(smem);\n\n        int write_col, write_row;\n        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) && WARPS_N == 1);\n        if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {\n            write_row = (tidx & 0x1c) / 4;\n            write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);\n        } else {\n            write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;\n            write_col = (tidx & 0x03);\n        }\n        write_col ^= (write_row & 0x07) * 4;\n\n        write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n    }\n\n    template<int M, int N>\n    inline __device__ void store(const uint4 (&regs)[M][N]) {\n        static_assert(COLS == Cta_tile::N);\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n                offset ^= 4 * BYTES_PER_STS;\n                fmha::sts(smem_ + offset + 0 * 
BYTES_PER_ROW, regs[mi][ni].y);\n                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n            }\n        }\n    }\n\n    uint32_t smem_;\n    uint32_t write_offset_;\n    uint32_t warp_m;\n    uint32_t warp_n;\n    uint32_t lane;\n};\n\ntemplate< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>\nstruct Smem_tile_mma_transposed : public Base {\n    enum { BYTES_PER_LDS = 16 };\n    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n    enum { WARPS_M = Base::WARPS_M };\n    enum { WARPS_N = Base::WARPS_N };\n    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n    using Fragment = typename Base::Fragment;\n    inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {\n\n        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n        int read_row, read_col;\n        read_row = (tidx & 0x0f);\n        read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;\n\n        read_col ^= (read_row & 0x07);\n        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    template<int M, int N>\n    inline __device__ void load(Fragment (&frag)[M][N]) {\n        static_assert(Base::COLS == Cta_tile::N);\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n                uint4 dst;\n                fmha::ldsmt(dst, this->smem_ + offset);\n                frag[mi][ni].reg(0) = dst.x;\n                frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!\n                frag[mi][ni].reg(2) = dst.y;\n                frag[mi][ni].reg(3) = dst.w;\n            }\n        }\n    }\n\n    uint32_t read_offset_;\n};\n\ntemplate< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>\nstruct Smem_tile_mma_epilogue : public Base {\n    enum { BYTES_PER_LDS 
= 16 };\n    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };\n    static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);\n    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n    enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };\n    static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);\n    enum { WARPS_M = Base::WARPS_M };\n    enum { WARPS_N = Base::WARPS_N };\n    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);\n    \n    using Acc = fmha::Fragment_accumulator;\n\n    inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {\n        const int read_row = tidx / THREADS_PER_ROW;\n        int read_col = tidx % THREADS_PER_ROW;\n        read_col ^= (read_row & 0x07);\n        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n    }\n\n    inline __device__ void load(uint4 (&data)[NUM_LDS]) {\n        for( int ii = 0; ii < NUM_LDS; ii++ ) {\n            size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;\n            fmha::lds(data[ii], this->smem_ + offset);\n        }\n    }\n\n    template<int M, int N>\n    inline __device__ void store(const Acc (&acc)[M][N]){\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                // 1st row - 4 elements per row.\n                float tmp00 = acc[mi][ni].elt(0);\n                float tmp01 = acc[mi][ni].elt(1);\n                float tmp02 = acc[mi][ni].elt(4);\n                float tmp03 = acc[mi][ni].elt(5);\n                // 2nd row - 4 elements per row.\n                float tmp10 = acc[mi][ni].elt(2);\n                float tmp11 = acc[mi][ni].elt(3);\n                float tmp12 = acc[mi][ni].elt(6);\n                float tmp13 = acc[mi][ni].elt(7);\n\n                uint32_t x = 
fmha::float2_to_half2(tmp00, tmp01);\n                uint32_t y = fmha::float2_to_half2(tmp02, tmp03);\n                uint32_t z = fmha::float2_to_half2(tmp10, tmp11);\n                uint32_t w = fmha::float2_to_half2(tmp12, tmp13);\n     \n                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);\n                offset ^= 4 * Base::BYTES_PER_STS;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);\n            }\n        }\n    }\n\n\n\n    template<int M, int N>\n    inline __device__ void store(const uint4 (&regs)[M][N]) {\n        for( int mi = 0; mi < M; mi++ ) {\n            for( int ni = 0; ni < N; ni++ ) {\n                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n                offset ^= 4 * Base::BYTES_PER_STS;\n                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);\n                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n            }\n        }\n    }\n\n    uint32_t read_offset_;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/softmax.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Sum_ {\n    enum { IS_SUM = 1 };\n    static inline __device__ float apply(float x, float y) {\n        return x + y;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Max_ {\n    enum { IS_SUM = 0 };\n    static inline __device__ float apply(float x, float y) {\n        return x > y ? 
x : y;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ float apply_exp_(float x, float max) {\n    return __expf(x - max);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile, typename Kernel_traits>\nstruct Softmax_base {\n\n    // The Mma tile.\n    using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n    // The number of MMAs in M/N dimensions.\n    enum { MMAS_M = Mma_tile::MMAS_M };\n    enum { MMAS_N = Mma_tile::MMAS_N };\n\n    // The number of groups of warps such that we have at most 4 warps writing consecutive elements.\n    enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };\n    // The number of elements that we are going to store per row.\n    enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };\n    // The number of rows.\n    enum { ROWS = Cta_tile::M * GROUPS };\n    // The total number of elements.\n    enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)\n        :  // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),\n          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {\n\n        // Move to the 1st mask loaded by the thread.\n        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);\n\n        // Extract the position in the warp.\n        int warp = tidx / Cta_tile::THREADS_PER_WARP;\n        int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n        // Decompose the warp index into M and N.\n        int warp_m = warp % Cta_tile::WARPS_M;\n        int warp_n = warp / Cta_tile::WARPS_M;\n\n        // Decompose the warp-n index into group/position-inside-the-group.\n        int warp_g = warp_n / ELEMENTS_PER_ROW;\n        int warp_i = warp_n % ELEMENTS_PER_ROW;\n\n    
    // The location written by the threads.\n        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;\n        int write_col = warp_i;\n\n        // Assemble the write pointer.\n        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];\n\n        // Assemble the read pointer.\n        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];\n    }\n\n    template<typename Mask>\n    inline __device__ void apply_mask(const Mask &mask) {\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ++ii ) {\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; ++jj ) {\n                        if( !mask.is_valid(mi, ni, ii, jj) ) {\n                            elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // Apply the exp to all the elements.\n    inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {\n                elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);\n            }\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {\n\n#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        if( Functor::IS_SUM ) {\n            // Apply the summation inside the thread.\n            float tmp[MMAS_M * 2][2];\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                tmp[mi][0] = 0.f;\n                tmp[mi][1] = 0.f;\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n  
                  tmp[mi][0] += elt_[mi][4 * ni + 0];\n                    tmp[mi][0] += elt_[mi][4 * ni + 1];\n                    tmp[mi][1] += elt_[mi][4 * ni + 2];\n                    tmp[mi][1] += elt_[mi][4 * ni + 3];\n                }\n                dst[mi] = tmp[mi][0] + tmp[mi][1];\n            }\n        } else\n#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        {\n            // Apply the functor for each row inside a thread.\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                dst[mi] = elt_[mi][0];\n                #pragma unroll\n                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {\n                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);\n                }\n            }\n        }\n\n        // Apply the functor for each row inside each group of 4 threads.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));\n            __syncwarp();\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));\n            __syncwarp();\n        }\n\n        // Store the different values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            if( tidx_ % 4 == 0 ) {\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];\n            }\n        }\n\n        // Make sure the values are in shared memory.\n        __syncthreads();\n\n        // Load 8 values (one for each warp). 
The /8 corresponds to /(4*2) where 4 is from the\n        // float4.\n        float4 tmp[1];\n        if( tidx_ < Cta_tile::M ) {\n            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];\n        }\n\n        // Compute the reduction of those 8 values in a binary-tree fashion.\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);\n        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);\n\n        // Make sure we can write to shared memory.\n        __syncthreads();\n\n        // Store the value back to shared memory.\n        if( tidx_ < Cta_tile::M ) {\n            smem_[tidx_] = tmp[0].x;\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Finally read the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];\n            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {\n\n#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        if( Functor::IS_SUM ) {\n            // Apply the summation inside the thread.\n            float tmp[MMAS_M * 2][2];\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                tmp[mi][0] = 0.f;\n                tmp[mi][1] = 0.f;\n                #pragma unroll\n                for( int ni = 0; ni < MMAS_N; ++ni ) {\n                    tmp[mi][0] += elt_[mi][4 * ni + 0];\n                    tmp[mi][0] += elt_[mi][4 * ni + 1];\n                    tmp[mi][1] += elt_[mi][4 * ni + 2];\n                    tmp[mi][1] += elt_[mi][4 * ni + 3];\n                }\n                dst[mi] = tmp[mi][0] + tmp[mi][1];\n            }\n        } else\n#endif  // 
defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)\n        {\n            // Apply the functor for each row inside a thread.\n            #pragma unroll\n            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n                dst[mi] = elt_[mi][0];\n                #pragma unroll\n                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {\n                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);\n                }\n            }\n        }\n\n        // Apply the functor for each row inside each group of 4 threads.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));\n            __syncwarp();\n            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));\n            __syncwarp();\n        }\n\n        // Store the different values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            if( tidx_ % 4 == 0 ) {\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];\n                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];\n            }\n        }\n\n        // Make sure the values are in shared memory.\n        __syncthreads();\n\n        // Load 8 values (one for each warp). 
The /8 corresponds to /(4*2) where 4 is from the\n        // float4.\n        float4 tmp[2];\n        if( tidx_ < Cta_tile::M ) {\n            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];\n            tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];\n        }\n\n        // Compute the reduction of those 8 values in a binary-tree fashion.\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);\n        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);\n        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);\n        tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);\n        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);\n        tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);\n\n        // Make sure we can write to shared memory.\n        __syncthreads();\n\n        // Store the value back to shared memory.\n        if( tidx_ < Cta_tile::M ) {\n            smem_[tidx_] = tmp[0].x;\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Finally read the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];\n            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];\n        }\n    }\n\n    // Do a CTA-wide reduction.\n    template<typename Functor>\n    inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {\n        static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));\n        if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {\n            reduce_1x4<Functor>(dst);\n        } else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {\n            reduce_1x8<Functor>(dst);\n        } else {\n            assert(false);\n        }\n\n        // Make sure we are done reading from shared memory.\n        __syncthreads();\n    }\n\n    // 
Scale all the elements.\n    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {\n        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.\n        float inv_sum[MMAS_M * 2];\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];\n        }\n\n        // Update the values.\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {\n                elt_[mi][ni] *= inv_sum[mi];\n            }\n        }\n    }\n\n    // The pointer to the mask.\n    const char *packed_mask_ptr_;\n    // Shared memory for the CTA-wide reduction.\n    float *smem_, *smem_write_, *smem_read_;\n    // The current thread index.\n    int tidx_;\n    // The elements.\n    float elt_[MMAS_M * 2][MMAS_N * 4];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile, typename Kernel_traits>\nstruct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {\n\n    // The base class.\n    using Base = Softmax_base<Cta_tile, Kernel_traits>;\n    // The fragment.\n    using Fragment_a = fmha::Fragment_a<fmha::Row>;\n\n    static_assert(Fragment_a::NUM_REGS == 4);\n\n    // The MMAs.\n    enum { MMAS_M = Base::MMAS_M };\n    enum { MMAS_N = Base::MMAS_N };\n\n    // The accumulators.\n    using Accumulator = fmha::Fragment_accumulator;\n    using Accumulator_out = Fragment<uint16_t, 8>;\n    static_assert(Accumulator_out::NUM_REGS == 4);\n\n    static_assert(std::is_same<Accumulator::Data_type, float>::value);\n\n    // Ctor.\n    template<typename Params>\n    inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)\n        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {\n    }\n\n    // Store the tile after 
softmax.\n    template<typename Gmem_tile>\n    inline __device__ void store(Gmem_tile &gmem_tile) {\n        Accumulator_out acc[MMAS_M][MMAS_N];\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N; ++ni ) {\n\n                // The elements.\n                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];\n                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];\n                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];\n                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];\n                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];\n                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];\n                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];\n                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];\n\n                // Transform to accumulators.\n                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);\n                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);\n                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);\n                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);\n            }\n        }\n\n        // Delegate to the gmem tile to store.\n        gmem_tile.store(acc);\n    }\n\n    // Pack the data to a fragment for the next GEMM.\n    template<int K, int M>\n    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {\n        #pragma unroll\n        for( int mi = 0; mi < M; ++mi ) {\n            #pragma unroll\n            for( int ki = 0; ki < K; ++ki ) {\n\n                // 1st row - 4 elements per row.\n                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];\n                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];\n                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];\n                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];\n\n                // 2nd row - 
4 elements per row.\n                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];\n                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];\n                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];\n                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];\n\n                // Pack to 4 registers.\n                dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);\n                dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);\n                dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);\n                dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);\n            }\n        }\n    }\n\n    // Scale FP32 fragments\n    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {\n        const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);\n\n        #pragma unroll\n        for( int mi = 0; mi < MMAS_M; ++mi ) {\n            #pragma unroll\n            for( int ni = 0; ni < MMAS_N; ++ni ) {\n                // 1st row - 4 elements per row.\n                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;\n                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;\n                // 2nd row - 4 elements per row.\n                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;\n                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;\n            }\n        }\n    }\n    const uint32_t params_scale_bmm1_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha/utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdlib.h>\n\nextern \"C\" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Row {};  \nstruct Col {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, bool = (M & (M-1)) == 0 >\nstruct Next_power_of_two {\n};\n\ntemplate< int M >\nstruct Next_power_of_two<  M, true > { enum { VALUE =   M }; };\ntemplate<>\nstruct Next_power_of_two<  3, false> { enum { VALUE =   4 }; };\ntemplate<>\nstruct Next_power_of_two<  5, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  6, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  7, false> { enum { VALUE =   8 }; };\ntemplate<>\nstruct Next_power_of_two<  9, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 10, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 11, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 12, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 13, false> { enum { VALUE =  16 }; 
};\ntemplate<>\nstruct Next_power_of_two< 14, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 15, false> { enum { VALUE =  16 }; };\ntemplate<>\nstruct Next_power_of_two< 24, false> { enum { VALUE =  32 }; };\ntemplate<>\nstruct Next_power_of_two< 48, false> { enum { VALUE =  64 }; };\ntemplate<>\nstruct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two<112, false> { enum { VALUE = 128 }; };\ntemplate<>\nstruct Next_power_of_two<144, false> { enum { VALUE = 256 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, bool = (N & (N-1)) == 0 >\nstruct Prev_power_of_two {\n};\n\ntemplate< int N >\nstruct Prev_power_of_two< N, true > { enum { VALUE = N }; };\ntemplate<>\nstruct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };\ntemplate<>\nstruct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };\ntemplate<>\nstruct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };\ntemplate<>\nstruct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, int N >\nstruct Div_up {\n    enum { VALUE = (M + N-1) / N };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B >\nstruct Max {\n    enum { VALUE = A >= B ? A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B, int C >\nstruct Max_3 {\n    enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int A, int B >\nstruct Min {\n    enum { VALUE = A <= B ? 
A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int SIZE_IN_BYTES >\nstruct Uint_from_size_in_bytes {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<1> {\n    using Type = uint8_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<2> {\n    using Type = uint16_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<4> {\n    using Type = uint32_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<8> {\n    using Type = uint2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Uint_from_size_in_bytes<16> {\n    using Type = uint4;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int WARPS_M, int WARPS_N, int WARPS_K >\nstruct Warp_masks {\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; 
};\ntemplate<>\nstruct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };\ntemplate<>\nstruct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };\ntemplate<>\nstruct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };\ntemplate<>\nstruct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };\ntemplate<>\nstruct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename T >\ninline __device__ __host__ T div_up(T m, T n) {\n    return (m + n-1) / n;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int clz(int x) {\n    for( int i = 31; i >= 0; --i ) {\n        if( (1 << i) & x ) {\n            return 31 - i;\n        }\n    }\n    return 32;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int find_log_2(int x, bool round_up = false) {\n    int a = 31 - clz(x);\n    if( round_up ) {\n        a += (x & (x-1)) ? 
1 : 0;\n    }\n    return a;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"add.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"min.f16x2 %0, %1, %2;\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {\n    uint32_t c;\n    asm volatile(\"mul.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hmul4(uint2 a, uint2 b) {\n    uint2 c;\n    c.x = hmul2(a.x, b.x);\n    c.y = hmul2(a.y, b.y);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint4 a, uint4 b) {\n    uint4 c;\n    c.x = hmul2(a.x, b.x);\n    c.y = hmul2(a.y, b.y);\n    c.z = hmul2(a.z, b.z);\n    c.w = hmul2(a.w, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint32_t a, uint4 b) {\n    uint4 c;\n    c.x = hmul2(a, b.x);\n    c.y = hmul2(a, b.y);\n    c.z = hmul2(a, b.z);\n    c.w = hmul2(a, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {\n    uint32_t res;\n#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ >= 800\n    asm volatile( \"max.f16x2 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(lb));\n#else\n    // Note: this fallback compares against zero, so the lb argument is ignored on pre-sm_80 devices.\n    const uint32_t zero = 0u;\n    asm volatile( \\\n        \"{\\n\" \\\n        \"\\t .reg .f16x2 sela;\\n\" \\\n        \"\\t set.gtu.u32.f16x2 sela, %1, %2;\\n\" \\\n        \"\\t and.b32 %0, sela, %1;\\n\" \\\n        \"}\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(zero));\n#endif\n    return res;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t habs2(uint32_t x) {\n    uint32_t res;\n    asm volatile( \"abs.f16x2 %0, %1;\\n\" : \"=r\"(res) : \"r\"(x));\n    return res;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename T >\nstatic inline __device__ T clamp(T x, T lb, T ub) {\n    return x < lb ? lb : (x > ub ? ub : x);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t clamp_to_zero(uint16_t x) {\n    uint16_t mask;\n    asm volatile(\"set.gtu %0, %1, 0;\" : \"=h\"(mask) : \"h\"(x));\n    return mask & x;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t float_to_half(float f) {\n    uint16_t h;\n    asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(h) : \"f\"(f));\n    return h;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(float a, float b) {\n    uint32_t c;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    asm volatile(\"cvt.rn.f16x2.f32 %0, %1, %2;\\n\" : \"=r\"(c) : \"f\"(b), \"f\"(a));\n#else\n    uint16_t lo = float_to_half(a);\n    uint16_t hi = float_to_half(b);\n    asm volatile(\"mov.b32 %0, {%1, %2};\\n\" : \"=r\"(c) : \"h\"(lo), \"h\"(hi));\n#endif\n    return 
c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float_to_half2(float a) {\n    return float2_to_half2(a,a);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(const float2 &f) {\n    return float2_to_half2(f.x, f.y);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {\n    uint2 d;\n    d.x = float2_to_half2(x, y);\n    d.y = float2_to_half2(z, w);\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {\n    uint32_t d;\n    asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {\n    uint32_t d;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    asm volatile(\"fma.rn.f16x2.relu %0, %1, %2, %3;\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n#else\n    d = hrelu2(hfma2(a, b, c));\n#endif\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h0_h0(uint32_t x) {\n    uint32_t y;\n    asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\\n\" \n        : \"=r\"(y) : \"r\"(x)); \n    return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float h0_to_float(uint32_t h2) {\n    float f;\n    asm volatile(\"{\\n\" \\\n    
    \".reg .f16 lo, hi;\\n\" \\\n        \"mov.b32 {lo, hi}, %1;\\n\" \\\n        \"cvt.f32.f16 %0, lo;\\n\" \\\n        \"}\\n\" : \"=f\"(f) : \"r\"(h2));\n    return f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h1_h1(uint32_t x) {\n    uint32_t y;\n    asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\\n\" \n        : \"=r\"(y) : \"r\"(x)); \n    return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {\n    uint16_t d;\n    asm volatile(\"add.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {\n    return hadd2(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd4(uint2 a, uint2 b) {\n    uint2 c;\n    c.x = hadd2(a.x, b.x);\n    c.y = hadd2(a.y, b.y);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd(uint2 a, uint2 b) {\n    return hadd4(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd8(uint4 a, uint4 b) {\n    uint4 c;\n    c.x = hadd2(a.x, b.x);\n    c.y = hadd2(a.y, b.y);\n    c.z = hadd2(a.z, b.z);\n    c.w = hadd2(a.w, b.w);\n    return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 fadd4(uint4 a, uint4 b) {\n    float4 c;\n    c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);\n    
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);\n    c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);\n    c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);\n    return reinterpret_cast<const uint4&>(c);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd(uint4 a, uint4 b) {\n    return hadd8(a, b);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float half_to_float(uint16_t h) {\n    float f;\n    asm volatile(\"cvt.f32.f16 %0, %1;\\n\" : \"=f\"(f) : \"h\"(h));\n    return f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float2 half2_to_float2(uint32_t x) {\n    uint16_t lo, hi;\n    asm volatile(\"mov.b32 {%0, %1}, %2;\\n\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(x));\n    return make_float2(half_to_float(lo), half_to_float(hi));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {\n    float2 tmp = half2_to_float2(h);\n    x = tmp.x;\n    y = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {\n    uint16_t d;\n    asm volatile(\"fma.rn.f16 %0, %1, %2, %3;\" : \"=h\"(d) : \"h\"(a), \"h\"(b), \"h\"(c));\n    return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {\n    uint16_t d;\n    asm volatile(\"mul.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n    return 
d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float sigmoid(float x) {\n    return 1.f / (1.f + expf(-x));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint16_t &dst) {\n    dst = uint16_t(0);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint32_t &dst) {\n    dst = 0u;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint2 &dst) {\n    dst = make_uint2(0u, 0u);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint4 &dst) {\n    dst = make_uint4(0u, 0u, 0u, 0u);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// P R E D I C A T E   P A C K I N G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\nenum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// G E N E R I C   P R E D I C A T E D   L D G S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M, typename Functor >\ninline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {\n\n    // The number of complete bytes (where we use all the predicates in a byte).\n    enum { COMPLETE = N / PREDS_PER_BYTE };\n    // Make sure we did allocate enough predicates.\n    static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, \"\");\n    // The remainder.\n    enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE 
};\n    // Make sure we got the math right and the remainder is between 0 and 3.\n    static_assert(REMAINDER >= 0 && REMAINDER <= 3, \"\");\n    // The mask to extract the predicates.\n    enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };\n\n    // Clear the fetch registers.\n    #pragma unroll\n    for( int ii = 0; ii < N; ++ii ) {\n        fct.clear(ii);\n    }\n\n    // Run complete steps.\n    bool p[PREDS_PER_BYTE];\n    #pragma unroll\n    for( int ii = 0; ii < COMPLETE; ++ii ) {\n\n        // The predicate.\n        uint32_t reg = preds[ii / BYTES_PER_REG];\n\n        // Extract the predicates.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);\n            p[jj] = (reg & mask) != 0u;\n        }\n\n        // Issue the loads.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);\n        }\n    }\n\n    // Skip the rest of the code if we do not have a remainder.\n    if( REMAINDER > 0 ) {\n\n        // The mask to extract the predicates.\n        enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };\n\n        // The predicate register.\n        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];\n\n        // Extract the predicates.\n        #pragma unroll\n        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {\n            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);\n            p[jj] = (reg & mask) != 0u;\n        }\n\n        // Issue the loads.\n        #pragma unroll\n        for( int ii = 0; ii < REMAINDER; ++ii ) {\n            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int M, typename Functor >\ninline __device__ void load_(Functor &fct, uint32_t preds) {\n    uint32_t tmp[1] = { preds };\n    load_<M>(fct, 
tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint8_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint8_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint16_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint16_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint32_t &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint32_t*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint2 &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint2*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint4 &dst, const void *ptr) {\n    dst = *reinterpret_cast<const uint4*>(ptr);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N >\nstruct Ldg_functor {\n    // Ctor.\n    inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])\n        : fetch_(fetch), ptrs_(ptrs) {\n    }\n\n    // Clear the element.\n    inline __device__ void clear(int ii) {\n        fmha::clear(fetch_[ii]);\n    }\n\n    // Trigger the loads.\n    inline __device__ void load(int ii, bool p) {\n        if( p ) {\n            ldg(fetch_[ii], ptrs_[ii]);\n        }\n    }\n\n    // The fetch registers.\n    Data_type (&fetch_)[N];\n    // The pointers.\n    const void* 
(&ptrs_)[N];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N, int M >\ninline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    Ldg_functor<Data_type, N> fct(fetch, ptrs);\n    load_<N>(fct, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint8_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint16_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint32_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint2, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N, int M >\ninline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n    ldg_<uint4, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint16_t &dst, uint32_t 
ptr) {\n    asm volatile(\"ld.shared.b16 %0, [%1];\\n\" : \"=h\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint32_t &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.b32 %0, [%1];\\n\" : \"=r\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint2 &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.v2.b32 {%0, %1}, [%2];\\n\" : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint4 &dst, uint32_t ptr) {\n    asm volatile(\"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x)\n        , \"=r\"(dst.y)\n        , \"=r\"(dst.z)\n        , \"=r\"(dst.w)\n        :  \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S M\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\\n\"\n        : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\\n\"\n        : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint2 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    
asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint4 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n        : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint8_t val) {\n    *reinterpret_cast<uint8_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint16_t val) {\n    *reinterpret_cast<uint16_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void 
*ptr, uint32_t val) {\n    *reinterpret_cast<uint32_t*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint2 val) {\n    *reinterpret_cast<uint2*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void *ptr, uint4 val) {\n    *reinterpret_cast<uint4*>(ptr) = val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint16_t val) {\n    asm volatile(\"st.shared.b16 [%0], %1;\\n\" : : \"r\"(ptr), \"h\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint32_t val) {\n    asm volatile(\"st.shared.b32 [%0], %1;\\n\" : : \"r\"(ptr), \"r\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint2 val) {\n    asm volatile(\"st.shared.v2.b32 [%0], {%1, %2};\\n\"\n        :\n        : \"r\"(ptr)\n        , \"r\"(val.x)\n        , \"r\"(val.y));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint4 val) {\n    asm volatile(\"st.shared.v4.b32 [%0], {%1, %2, %3, %4};\\n\"\n        :\n        : \"r\"(ptr)\n        , \"r\"(val.x)\n        , \"r\"(val.y)\n        , \"r\"(val.z)\n        , \"r\"(val.w));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename Data_type, int N >\ninline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {\n    #pragma 
unroll\n    for( int ii = 0; ii < N; ++ii ) {\n        sts(ptrs[ii], data[ii]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {\n    sts_<uint16_t, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {\n    sts_<uint32_t, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {\n    sts_<uint2, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {\n    sts_<uint4, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <cuda.h>\n#include <vector>\n\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n\n#include <fmha_utils.h>\n\n\nconstexpr int TOTAL_DIM = 0;\nconstexpr int THREE_DIM = 1;\nconstexpr int H_DIM = 2;\nconstexpr int D_DIM = 3;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Qkv_params {\n    // The QKV matrices.\n    void *qkv_ptr;\n\n    // The stride between rows of the Q, K and V matrices.\n    size_t qkv_stride_in_bytes;\n\n    // The number of heads.\n    int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fused_multihead_attention_fprop_params : public Qkv_params {\n\n    // The dQKV matrices.\n    void *dqkv_ptr;\n\n    // Temporary for dKV.\n    void *dkv_ptr;\n\n    // The O matrix (output).\n    void *o_ptr;\n\n    // The stride between rows of O.\n    int64_t o_stride_in_bytes;\n\n    // The pointer to the S matrix, overwritten by the dP matrix (bwd).\n    void *s_ptr;\n    // The stride between rows of the S matrix.\n    int64_t s_stride_in_bytes;\n\n    // The dimensions.\n    int b, s, d;\n\n    // The scaling factors for the kernel.\n    uint32_t scale_bmm1, scale_softmax, scale_bmm2;\n\n    // array of length b+1 holding starting offset of each sequence.\n    int 
*cu_seqlens;\n\n    // The dropout probability (probability of keeping an activation).\n    float p_dropout;\n\n    // Scale factor of 1 / (1 - p_dropout).\n    float rp_dropout;\n\n    // Scale factor of 1 / (1 - p_dropout), in half2.\n    uint32_t scale_dropout;\n\n    // Random state.\n    at::PhiloxCudaState philox_args;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\nvoid run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);\n\nvoid run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream); \n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);\n\nvoid fmha_run_noloop_reduce(void *out,\n                            const void *in,\n                            const int *cu_seqlens,\n                            const int hidden_size,\n                            const int batch_size,\n                            const int total,\n           
                 const int num_chunks,\n                            cudaStream_t stream);\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 128 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 256 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 384 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( 
smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n#include \"fmha_dgrad_kernel_1xN_reload_nl.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::compute_dv_1xN<Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\ntemplate<int CHUNKS>\n__global__\nvoid fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){\n    fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);\n    fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 512 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * 
Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(\n            fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n    static_assert(smem_size_s == 16 * 512 * 2);\n    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n    auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n\n    if( num_chunks == 2 ) {\n        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n    } else if( num_chunks == 3 ) {\n        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;\n    } else {\n        assert(false && \"Unsupported number of chunks\");\n    }\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b, num_chunks);\n\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n\n    FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dv =\n        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n    static_assert(Cta_tile_dv::N == 64);\n    static_assert(Cta_tile_dv::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    // using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n    
// The shared memory tile to reload Q as fragment b.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dV.\n    using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle dV.\n    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n    static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n    static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q\n    // Allocate the shared memory tile loader for Q.\n    
Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n    static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n    static_assert(Mma_tile_dv::MMAS_K == 1);\n    smem_qt.load(frag_qt[0], 0);\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(\n        params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n    // Load over the entire sequence length.\n    for( int l = 0; l < STEPS; l++ ) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Load S\n        uint4 s_regs[M][N];\n        gmem_s.load(s_regs, mask);\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Store s * dmask to smem for transpose\n        smem_s.store(s_regs);\n\n        // Declare the accumulators for the 1st gemm.\n        // Do the 
final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n        // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe\n        if( l < STEPS - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        float s_mat[2 * M][4 * N];\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n            }\n        }\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n                        const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;\n                        const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n                        s_dmask = fabsf(s_dmask);\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);\n                    }\n                }\n            }\n        }\n\n        float p_sum[2 * M];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n                    }\n                }\n            }\n        }\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n        smem_s.load(frag_s);\n        for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {\n            for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {\n                for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {\n                    frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n                    frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        gmem_s.store(softmax.elt_, mask);\n        gmem_s.move();\n\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], 
frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dv::MMAS_K;\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n        // Commit the values for Q into shared memory.\n        if(l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure we are reading from the correct buffer.\n        smem_q.move_to_next_read_buffer();\n        smem_qt.move_to_next_read_buffer();\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n        smem_k.load(frag_k[0], 0);\n        smem_qt.load(frag_qt[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue swizzle for dV\n    Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n    smem_dv.store(acc_dv);\n\n    __syncthreads();\n    uint4 dv_out[Smem_tile_dv::NUM_LDS];\n    smem_dv.load(dv_out);\n    Qkv_params dv_params;\n    dv_params.qkv_ptr = params.dqkv_ptr;\n    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n    dv_params.h = params.h;\n    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);\n    gmem_dv.store(dv_out);\n}\n\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dk =\n        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n    static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n    static_assert(Cta_tile_dk::N == 64);\n    
static_assert(Cta_tile_dk::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dK.\n    using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle dK.\n    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n    static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);\n\n    // The shared memory tile to reload Q transposed.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n\n\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n    static_assert(M == 
Mma_tile_o::MMAS_M);\n    static_assert(N == Mma_tile_o::MMAS_K);\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. 
We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n    // Load dP\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n    gmem_s.move();\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n    smem_qt.load(frag_qt[0], 0);\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n    // Load over the entire sequence length.\n    for( int l=0;l<STEPS;l++) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Pack dP as Fragment_a\n        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                frag_p[ni][mi].reg(0) = dst.x;  // row 0, cols 0,1\n                frag_p[ni][mi].reg(1) = dst.z;  // row 8, cols 0,1\n                frag_p[ni][mi].reg(2) = dst.y;  // row 0, cols 8,9\n                frag_p[ni][mi].reg(3) = dst.w;  // row 8, cols 8,9\n            }\n        }\n\n        // Declare the 
accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T. dQ = dP x dK\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_o::MMAS_K;\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Store dP to smem for transpose\n        smem_s.store(s_regs);\n        if(l < STEPS - 1) {\n            // Load next part of S\n            gmem_s.load(s_regs, mask);\n            gmem_s.move();\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n        
smem_s.load(frag_s);\n\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dk::MMAS_K;\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Commit the values for Q into shared memory.\n        if( l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_qt.load(frag_qt[0], 0);\n        smem_k.load(frag_k[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue swizzle for dK\n    Smem_tile_dk smem_dk(&smem_[0], tidx);\n    smem_dk.store(acc_dk);\n    __syncthreads();\n    uint4 dk_out[Smem_tile_dk::NUM_LDS];\n    smem_dk.load(dk_out);\n    Qkv_params dk_params;\n    dk_params.qkv_ptr = params.dqkv_ptr;\n    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n    dk_params.h = params.h;\n    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);\n    gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dv = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n    static_assert(Cta_tile_dv::N == 64);\n    static_assert(Cta_tile_dv::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n    // The shared memory tile to reload Q as fragment b.\n    
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store dV.\n    using Gmem_tile_dv = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, \n                                             fmha::BITS_PER_ELEMENT_B, \n                                             Cta_tile_p::N, //S, \n                                             Cta_tile_p::K, //D, \n                                             2*CHUNKS>;\n\n    // The shared memory tile to swizzle dV.\n    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n    static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n    static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the chunk.\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_do gmem_q(params, binfo, 
tidx);  // treating dout as Q\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    Noloop nl_traits(bidc);\n    nl_traits.move_all(gmem_q, gmem_s);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n    static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n    static_assert(Mma_tile_dv::MMAS_K == 1);\n    smem_qt.load(frag_qt[0], 0);\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(\n        params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n    // Load over the entire sequence length.\n    for(int l = 0; l < nl_traits.num_steps_;l++) {\n        const int loop = nl_traits.offset_loop_count(l);\n        if( loop >= binfo.actual_seqlen ) break;\n\n        uint4 s_regs[M][N];\n        gmem_s.load(s_regs, mask);\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n\n        smem_s.store(s_regs);\n\n        // Declare the accumulators for the 1st gemm.\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n        }\n    
    // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe\n        if(l < nl_traits.num_steps_ - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        float s_mat[2 * M][4 * N];\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n                fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n                fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n            }\n        }\n\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                         float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n                        const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;\n                        const float d_s= drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n                        s_dmask = fabsf(s_dmask);\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);\n                    }\n                }\n            }\n        }\n\n        float p_sum[2 * M];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ii = 0; ii < 2; ii++ ) {\n                #pragma unroll\n                for( int ni = 0; ni < N; ni++ ) {\n                    #pragma unroll\n                    for( int jj = 0; jj < 4; jj++ ) {\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;\n                        softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n                    }\n                }\n            }\n        }\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n        smem_s.load(frag_s);\n        for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {\n            for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {\n                for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {\n                    frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n                    frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        gmem_s.store(softmax.elt_, mask);\n        gmem_s.move();\n\n        static_assert(Mma_tile_dv::MMAS_K == 1);  // DEBUG\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in 
registers.\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dv::MMAS_K;\n            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n        // Commit the values for Q into shared memory.\n        if(l < nl_traits.num_steps_ - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure we are reading from the correct buffer.\n        smem_q.move_to_next_read_buffer();\n        smem_qt.move_to_next_read_buffer();\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n        smem_k.load(frag_k[0], 0);\n        smem_qt.load(frag_qt[0], 0);\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!\n\n    // Epilogue swizzle for dV\n    Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n    smem_dv.store(acc_dv);\n\n    __syncthreads();\n\n    uint4 dv_out[Smem_tile_dv::NUM_LDS];\n    smem_dv.load(dv_out);\n    Qkv_params dv_params;\n    dv_params.qkv_ptr = params.dkv_ptr;\n    dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n    dv_params.h = params.h;\n    Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);\n    gmem_dv.store(dv_out);\n}\n\ntemplate<int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_dk = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 
1, Cta_tile_p::WARPS_M>;\n\n    static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n    static_assert(Cta_tile_dk::N == 64);\n    static_assert(Cta_tile_dk::K == 16);\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = Gmem_tile_dq<Cta_tile_o>;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store dK.\n    using Gmem_tile_dk = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, \n                                             fmha::BITS_PER_ELEMENT_B, \n                                             Cta_tile_p::N, //S, \n                                             Cta_tile_p::K, //D, \n                                             2*CHUNKS>;\n\n    // The shared memory tile to swizzle dK.\n    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n    static_assert(Smem_tile_dk::THREADS_PER_ROW == 
Gmem_tile_dk::THREADS_PER_ROW);\n\n    // The shared memory tile to reload Q transposed.\n    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n    // The global memory tile to load dP, stored in S\n    using Gmem_tile_s = Gmem_tile_mma_s<Cta_tile_p>;\n    // The shared memory tile to transpose dP.\n    using Smem_tile_st = Smem_tile_mma_transposed<Cta_tile_p>;  \n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    enum { M = Mma_tile_p::MMAS_M };\n    enum { N = Mma_tile_p::MMAS_N };\n    static_assert(M == Mma_tile_o::MMAS_M);\n    static_assert(N == Mma_tile_o::MMAS_K);\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q (as B).\n    Smem_tile_qt smem_qt(&smem_[0], tidx);\n    // Allocate the global memory tile loader for dP.\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n    // Allocate the shared memory tile loader for dP.\n    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory 
tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n    Noloop nl_traits(bidc);\n\n    nl_traits.move_all(gmem_q, gmem_o, gmem_s);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_qt);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_qt);\n    gmem_k.commit(smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n    smem_qt.load(frag_qt[0], 0);\n    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n    smem_k.load(frag_k[0], 0);\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    enum { THREADS_PER_ROW = 32 };\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n    // Load over the entire sequence length.\n    for(int l=0;l < nl_traits.num_steps_; l++) {\n\n        // Pack dP as Fragment_a\n        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        #pragma unroll\n        for( int mi = 0; mi < M; mi++ ) {\n            #pragma unroll\n            for( int ni = 0; ni < N; ni++ ) {\n                uint4 &dst = s_regs[mi][ni];\n                frag_p[ni][mi].reg(0) = dst.x;\n                frag_p[ni][mi].reg(1) = dst.z;\n                frag_p[ni][mi].reg(2) = dst.y;\n                frag_p[ni][mi].reg(3) = dst.w;\n            }\n        }\n        smem_s.store(s_regs);\n        if(l < nl_traits.num_steps_- 1) {\n            // Load next part of S\n            gmem_s.move();\n            gmem_s.load(s_regs, mask);\n            // Trigger the load for the next Q values.\n          
  smem_qt.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_qt);\n        }\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T. dQ = dP x dK\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_k.load(frag_k[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_o::MMAS_K;\n            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n        }\n\n        static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n        smem_s.load(frag_s);\n\n        static_assert(Mma_tile_dk::MMAS_K == 1);  // DEBUG\n\n        #pragma unroll\n        for( int ki = 1; ki < 
Mma_tile_dk::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_qt.load(frag_qt[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_dk::MMAS_K;\n            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n        }\n\n        // Commit the values for Q into shared memory.\n        if(l < nl_traits.num_steps_- 1) {\n            gmem_q.commit(smem_qt);\n            __syncthreads();\n            // Trigger the loads for the values of Q for the next iteration.\n            smem_qt.load(frag_qt[0], 0);\n            smem_k.load(frag_k[0], 0);\n        }\n\n    }  // Outer loop over the sequence length.\n\n    // Epilogue for dK = dP' * dq. We're fully exposed to this!\n\n    // Epilogue swizzle for dK\n    Smem_tile_dk smem_dk(&smem_[0], tidx);\n    smem_dk.store(acc_dk);\n    \n    __syncthreads();\n    \n    uint4 dk_out[Smem_tile_dk::NUM_LDS];\n    smem_dk.load(dk_out);\n    Qkv_params dk_params;\n    dk_params.qkv_ptr = params.dkv_ptr;\n    dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n    dk_params.h = params.h;\n    Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);\n    gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN_reload_v.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\nvoid run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;\n\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n#include \"fmha_fprop_kernel_1xN_nl.h\"\n\nusing Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, true>(params);\n}\n\nextern \"C\" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN<Kernel_traits, false>(params);\n}\n\ntemplate<int CHUNKS>\n__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);\n}\n\ntemplate<int CHUNKS>\n__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {\n    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);\n}\n\n\nvoid run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {\n\n    auto kernel = is_training ? 
&fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    dim3 grid(params.h, params.b);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n\nvoid run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {\n\n    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;\n    if( num_chunks == 2 ) {\n        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;\n    } else if( num_chunks == 3 ) {\n        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;\n    } else if( num_chunks == 4 ) {\n        kernel = is_training ? 
&fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>\n                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;\n    } else {\n        assert(false && \"Unsupported num_chunks\");\n    }\n\n    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);\n    if( smem_size >= 48 * 1024 ) {\n        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n\n    dim3 grid(params.h, params.b, num_chunks);\n    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using 
Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    // Load the fragments for K. We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n    }\n\n    // Load the fragments for V. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n        smem_v.load(frag_v[ki], ki);\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    enum { THREADS_PER_ROW = 32 };\n    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n\n    // Load over the entire sequence length.\n    for( int l = 0; l < STEPS; l++ ) {\n        const int loop = l * Cta_tile_p::M;\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n\n    // Do this part of P^T = (Q * K^T)^T.\n    #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Load the mask for that iteration.\n        mask.load(l);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the mask.\n        softmax.apply_mask(mask);\n\n        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V 
&& l == 0 ) {\n            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction\n            __syncthreads();\n        }\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from\n                        // pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + 
ii][4 * ni + 3]);\n                    }\n                }\n            }\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        // Trigger the load for the next Q values.\n        if(l < STEPS - 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n        using Frag_p = fmha::Fragment_a< fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    //\"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T.\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from 
shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if(l < STEPS - 1) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n        // Trigger the loads for the values of Q for the next iteration.\n        smem_q.load(frag_q[0], 0);\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>\ninline __device__ void device_1xN_nl(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    // The global memory tile to store S/D.\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int bidc = blockIdx.z;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    Noloop nl_traits(bidc);\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    nl_traits.move_all(gmem_q, gmem_o, gmem_s);\n\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n    smem_q.load(frag_q[0], 0);\n\n    // Load the fragments for K. We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n    }\n\n    // Load the fragments for V. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n        smem_v.load(frag_v[ki], ki);\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    // The number of threads per row.\n    enum { THREADS_PER_ROW = 32 };\n\n    // Load over the entire sequence length.\n    for(int l = 0; l < nl_traits.num_steps_;l++) {\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n\n        // Do this part of P^T = (Q * K^T)^T.\n        #pragma unroll\n        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[ki & 1], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Do the final stage of math.\n        {\n            int ki = Mma_tile_p::MMAS_K;\n            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n        }\n\n        // Trigger the load for the next Q values.\n        if( l < nl_traits.num_steps_- 1) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n\n\n\n        // Load the mask for that iteration.\n        mask.load(nl_traits.loop_offset_ + l);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the 
mask.\n        softmax.apply_mask(mask);\n\n        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {\n            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction\n            __syncthreads();\n        }\n\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);\n                    }\n                }\n            }\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        using Frag_p = fmha::Fragment_a<fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    // \"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 2nd gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        // Do this part of O = P^T * V^T.\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Make sure the data was read from shared memory.\n            if( ii < Gmem_tile_o::LOOPS - 1 ) {\n                __syncthreads();\n            }\n\n            // Output the 
values.\n            gmem_o.store(out, ii);\n        }\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if( l < nl_traits.num_steps_- 1) {\n            gmem_q.commit(smem_q);\n            __syncthreads();\n            smem_q.load(frag_q[0], 0);\n        }\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"fmha_kernel.h\"\n#include <fmha/kernel_traits.h>\n#include <fmha/gemm.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {\n\n    // The description of the CTA tile for the 1st batched GEMM.\n    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n    // The description of the CTA tile for the 2nd batched GEMM.\n    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n    // The MMA tile for the 1st GEMM.\n    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n    // The MMA tile for the 2nd GEMM.\n    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n    // The global memory tile to load Q.\n    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n    // The shared memory tile to swizzle Q.\n    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n    // The global memory tile to load K.\n    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n    // The shared memory tile to swizzle K.\n    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n    // The global memory tile to load V.\n    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n    // The shared memory tile to swizzle V.\n    using 
Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n    // The global memory tile to store O.\n    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n    // The shared memory tile to swizzle O.\n    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.x;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n    if( binfo.stop_early() )\n        return;\n\n    Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n    auto seeds = at::cuda::philox::unpack(params.philox_args);\n    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));\n\n    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);\n\n    // Allocate the global memory tile loader for K.\n    Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n    // Allocate the shared memory tile loader for K.\n    Smem_tile_k smem_k(&smem_[0], tidx);\n\n    // Allocate the global memory tile loader for V.\n    Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n    // The base pointer of smem_v;\n    char *smem_v_ = nullptr;\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        smem_v_ = &smem_[0];\n    } else {\n        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];\n    }\n    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);\n    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);\n    // Allocate the shared memory tile loader for V. 
We use the same as K so be careful!!!\n    Smem_tile_v smem_v(smem_v_, tidx);\n\n    // Allocate the global memory tile loader for Q.\n    Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n    // Allocate the shared memory tile loader for Q.\n    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);\n\n    // Allocate the global memory tile loader for O.\n    Gmem_tile_o gmem_o(params, binfo, tidx);\n    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);\n\n    // Trigger the loads for Q.\n    gmem_q.load(smem_q);\n    // Trigger the loads for K.\n    gmem_k.load(smem_k);\n    // Trigger the loads for V.\n    gmem_v.load(smem_v);\n\n\n    // Commit the data for Q and K to shared memory.\n    gmem_q.commit(smem_q);\n    gmem_k.commit(smem_k);\n\n    // Commit the data for V to shared memory.\n    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        gmem_v.commit(smem_v);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Load the fragments for Q.\n    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];\n\n    // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n    #pragma unroll\n    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n        smem_k.load(frag_k[ki], ki);\n    }\n\n    // Commit the data for V to shared memory if it has not been done already.\n    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {\n        // Make sure we are done loading the fragments for K.\n        __syncthreads();\n\n        // Commit the data to shared memory for V.\n        gmem_v.commit(smem_v);\n\n    }\n\n    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };\n\n    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);\n\n    // Create the object to do the softmax.\n    using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;\n    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);\n    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));\n\n    enum { THREADS_PER_ROW = 32 };\n\n    const float pinv = 1.f / params.p_dropout;\n\n    // Load over the entire sequence length.\n    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {\n        if( loop >= binfo.actual_seqlen )\n            break;\n\n        // Declare the accumulators for the 1st gemm.\n        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {\n\n            // Trigger the load from shared memory for the next series of Q values.\n            smem_q.load(frag_q[0], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);\n        }\n\n        // Load the mask 
for that iteration.\n        mask.load(outer);\n\n        // Convert from the accumulator type to FP32 for Softmax.\n        softmax.unpack(acc_p);\n\n        // Apply the mask.\n        softmax.apply_mask(mask);\n\n        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);\n\n        // Compute the max.\n        float p_max[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Max_>(p_max);\n\n        // Make sure we are done reading shared memory.\n        __syncthreads();\n        // Compute the exponential value.\n        softmax.apply_exp(p_max);\n        // Compute the sum.\n        float p_sum[Mma_tile_p::MMAS_M * 2];\n        softmax.template reduce<fmha::Sum_>(p_sum);\n\n        // Finalize softmax on the accumulators of P^T.\n        softmax.scale(p_sum);\n\n        __syncthreads();\n        if( Is_training ) {\n            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < 2; ii++ ) {\n                    #pragma unroll\n                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {\n                        float4 tmp = uniform4(ph());\n                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from\n                        // pre-existing zeros\n                        softmax.elt_[2 * mi + ii][4 * ni + 0] =\n                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 1] =\n                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);\n                        softmax.elt_[2 * mi + ii][4 * ni + 2] =\n                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);\n           
             softmax.elt_[2 * mi + ii][4 * ni + 3] =\n                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);\n                    }\n                }\n            }\n\n            gmem_s.store(softmax.elt_, mask);\n            gmem_s.move();\n        }\n\n        // Trigger the load for the next Q values.\n        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {\n            smem_q.move_to_next_write_buffer();\n            gmem_q.move();\n            gmem_q.load(smem_q);\n        }\n        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];\n\n        using Frag_p = fmha::Fragment_a< fmha::Row>;\n        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n        softmax.pack(frag_p);\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {\n            #pragma unroll\n            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {\n                #pragma unroll\n                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {\n                    // \"Apply\" the dropout.\n                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n                }\n            }\n        }\n\n        // Declare the accumulators for the 2nd gemm.\n        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n        #pragma unroll\n        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {\n            // Trigger the load from shared memory for the next series of V values.\n            smem_v.load(frag_v[0], ki);\n            // Do the math for the values already in registers.\n            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);\n        }\n\n        // Loop over MMAS_M.\n        #pragma unroll\n        for( int ii = 0; ii < 
Gmem_tile_o::LOOPS; ++ii ) {\n\n            // Swizzle the elements and do the final reduction.\n            smem_o.store(acc_o, ii);\n\n            // Make sure the data is in shared memory.\n            __syncthreads();\n\n            // Load from shared memory.\n            uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n            smem_o.load(out);\n\n            // Always sync after last iter: shared smem_q and smem_o!\n            __syncthreads();\n\n            // Output the values.\n            gmem_o.store(out, ii);\n        }\n        // same smem as o\n\n        // Move to the next part of the output.\n        gmem_o.move();\n\n        // Commit the values for Q into shared memory.\n        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {\n            gmem_q.commit(smem_q);\n        }\n\n        // Make sure the data is in shared memory.\n        __syncthreads();\n\n    }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <multihead_attn/philox.h>\n\n#include <fmha.h>\n#include <fmha/utils.h>\n#include <fmha/smem_tile.h>\n#include <fmha/gmem_tile.h>\n#include <fmha/mask.h>\n#include <fmha/softmax.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int THREADS_PER_CTA>\nstruct BlockInfoPadded {\n\n    template<typename Params>\n    __device__ BlockInfoPadded(const Params &params,\n                               const int bidb,\n                               const int bidh,\n                               const int tidx)\n        : bidb(bidb), bidh(bidh), h(params.h) {\n\n        // The block index.\n        sum_s = params.cu_seqlens[bidb];\n        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;\n        bidx = sum_s * params.h + bidh;\n\n        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;\n    }\n\n    __device__ bool stop_early() const {\n        return actual_seqlen == 0;\n    }\n\n    int actual_seqlen;\n    int bidx;\n    int sum_s;\n    int bidh;\n    int bidb;\n    int tidx_global;\n    int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int CHUNKS, typename Cta_tile> \nstruct Noloop_traits{\n    // Interpretation of Cta_tile dims, i.e. 
Cta_tile_p:\n    enum{ STEP = Cta_tile::M };\n    enum{ SEQLEN = Cta_tile::N };\n\n    // The size of the subsequence this CTA is processing\n    enum { SUBSEQ = SEQLEN / CHUNKS };\n    static_assert(SUBSEQ * CHUNKS == SEQLEN);\n\n    // The number of steps to process the subsequence\n    enum { NUM_STEPS = SUBSEQ / STEP };\n    static_assert(NUM_STEPS  * Cta_tile::M == SUBSEQ);\n\n    inline __device__ Noloop_traits(const int bidc) \n        : loop_offset_(NUM_STEPS * bidc)\n        , bidc_(bidc) {\n    }\n\n    template<typename ... Tiles> \n    inline __device__ void move_all(Tiles & ... tiles) const {\n        using expand_type = int[];\n        for( int s = 0; s < loop_offset_; s++ ) {\n            expand_type{ (tiles.move(), 0)... };\n        }\n    }\n\n    inline __device__ int get_idx_dk() const {\n        //return bidc_;\n        return bidc_ * 2 + 0;\n    }\n\n    inline __device__ int get_idx_dv() const {\n        //return CHUNKS + bidc_;\n        return bidc_ * 2 + 1;\n    }\n\n    inline __device__ int offset_loop_count(const int l) {\n        // convert loop counter to position in the outer sequence\n        return (loop_offset_ + l) * STEP;\n    }\n\n    const int loop_offset_;\n    const uint32_t bidc_;\n    const int num_steps_ = NUM_STEPS;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Cta_tile> \nstruct Noloop_traits<3, Cta_tile>{\n    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:\n    enum{ STEP = Cta_tile::M };\n    enum{ SEQLEN = Cta_tile::N };\n\n    static_assert(STEP == 16 && SEQLEN == 512);\n\n    inline __device__ Noloop_traits(const int bidc)\n        : bidc_(bidc)\n        , num_steps_(bidc < 2 ? 11 : 10) \n        , loop_offset_(bidc * 11) {\n    }\n\n    template<typename ... Tiles> \n    inline __device__ void move_all(Tiles & ... 
tiles) const {\n        using expand_type = int[];\n        for( int s = 0; s < loop_offset_; s++ ) {\n            expand_type{ (tiles.move(), 0)... };\n        }\n    }\n\n    inline __device__ int get_idx_dk() const {\n        //return bidc_;\n        return bidc_ * 2 + 0;\n    }\n\n    inline __device__ int get_idx_dv() const {\n        //return CHUNKS + bidc_;\n        return bidc_ * 2 + 1;\n    }\n\n    inline __device__ int offset_loop_count(const int l) {\n        // convert loop counter to position in the outer sequence\n        return (loop_offset_ + l) * STEP;\n    }\n\n    const int loop_offset_;\n    const uint32_t bidc_;\n    const int  num_steps_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n\ninline __device__ float4 ldg128(const void *ptr) {\n    return *static_cast<const float4 *>(ptr);\n}\n\ninline __device__ void stg128(void *ptr, const float4 &data) {\n    *static_cast<float4 *>(ptr) = data;\n}\n\ntemplate<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>\n__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,\n                                                                     const void *__restrict__ in,\n                                                                     const int *__restrict__ cu_seqlens,\n                                                                     const int batch_size) {\n\n    enum { BYTES_PER_LDG = 16 };\n    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };\n\n    // One CTA hidden vector for K and V\n    enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };\n    // The stride in bytes in dQKV\n    enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };\n    // The offset in bytes in dQKV to the dKV part for non-interleaved heads\n    enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };\n\n    static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); \n\n    // Size in bytes of the input tile\n    enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };\n\n    enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };\n\n    enum { LDGS = BYTES_PER_ROW / 
BYTES_PER_CTA };\n    static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);\n\n    union Vec_t {\n        float4 raw;\n        T elt[NUM_ELTS];\n    };\n\n    // ZERO-OUT invalid positions in dQKV\n    const int total = cu_seqlens[batch_size];\n    if(blockIdx.x >= total){\n        enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };\n        enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };\n\n        const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);\n\n        char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;\n\n        for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){\n            stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);\n        }\n\n        return;\n    }\n\n    // SETUP\n    const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;\n    const char *ptr_in = static_cast<const char *>(in) + offset_in;\n\n    const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;\n    char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;\n\n    // LOAD\n\n    Vec_t local_in[CHUNKS][LDGS];\n\n    #pragma unroll\n    for( int c = 0; c < CHUNKS; c++ ) {\n        #pragma unroll\n        for( int l = 0; l < LDGS; l++ ) {\n            int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;\n            local_in[c][l].raw = ldg128(ptr_in + offset);\n        }\n    }\n\n    // UNPACK\n    float acc[LDGS][NUM_ELTS];\n\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        #pragma unroll\n        for( int e = 0; e < NUM_ELTS; e++ ) {\n            acc[l][e] = float(local_in[0][l].elt[e]);\n        }\n    }\n\n    // COMPUTE\n    #pragma unroll\n    for( int c = 1; c < CHUNKS; c++ ) {\n        #pragma unroll\n        for( int l = 0; l < LDGS; l++ ) {\n            #pragma unroll\n            for( int e = 0; e < NUM_ELTS; e++ ) {\n                acc[l][e] += float(local_in[c][l].elt[e]);\n            }\n        }\n    }\n\n    // PACK\n    Vec_t 
local_out[LDGS];\n\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        #pragma unroll\n        for( int e = 0; e < NUM_ELTS; e++ ) {\n            local_out[l].elt[e] = T(acc[l][e]);\n        }\n    }\n\n    // STORE\n    #pragma unroll\n    for( int l = 0; l < LDGS; l++ ) {\n        const int offset = l * BYTES_PER_CTA;\n        stg128(ptr_out + offset, local_out[l].raw);\n    }\n}\n\nvoid fmha_run_noloop_reduce(void *out,\n                            const void *in,\n                            const int *cu_seqlens,\n                            const int hidden_size,\n                            const int batch_size,\n                            const int total,\n                            const int num_chunks,\n                            cudaStream_t stream) {\n\n    const int blocks = total;\n\n    if(hidden_size == 1024){\n\n        constexpr int HIDDEN_SIZE = 1024;\n        constexpr int THREADS = 256;\n\n        if( num_chunks == 2 ) {\n            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n        } else if( num_chunks == 3 ) {\n            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n        } else {\n            assert(false && \"Unsupported num_chunks\");\n        }\n\n    }else{\n        assert(false && \"Unsupported hidden_size\");\n    }\n\n    FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/fmha/src/fmha_utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n * \n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime_api.h>\n#include <cuda_fp16.h>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define FMHA_CHECK_CUDA( call )                                                                    \\\n    do {                                                                                           \\\n        cudaError_t status_ = call;                                                                \\\n        if( status_ != cudaSuccess ) {                                                             \\\n            fprintf( stderr,                                                                       \\\n                     \"CUDA error (%s:%d): %s\\n\",                                                   \\\n                     __FILE__,                                                                     \\\n                     __LINE__,                                                                     \\\n                     cudaGetErrorString( status_ ) );                                              \\\n            exit( 1 );                                                                             \\\n        }                                                                                          \\\n    } while( 0 
)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nenum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {\n    if( dtype == DATA_TYPE_FP16 ) {\n        half x = __float2half_rn( norm );\n        uint16_t h = reinterpret_cast<const uint16_t &>( x );\n        ushort2 h2 = { h, h };\n        alpha = reinterpret_cast<const uint32_t &>( h2 );\n    } else if( dtype == DATA_TYPE_FP32 ) {\n        alpha = reinterpret_cast<const uint32_t &>( norm );\n    } else if( dtype == DATA_TYPE_INT32 ) {\n        int32_t inorm = static_cast<int32_t>( norm );\n        alpha = reinterpret_cast<const uint32_t &>( inorm );\n    } else {\n        assert( false );\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {\n    switch( dtype ) {\n    case DATA_TYPE_FP32:\n        return n * 4;\n    case DATA_TYPE_FP16:\n        return n * 2;\n    case DATA_TYPE_INT32:\n        return n * 4;\n    case DATA_TYPE_INT8:\n        return n;\n    default:\n        assert( false );\n        return 0;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/batch_norm.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include \"batch_norm.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) {\n  return ((x + multiple - 1) / multiple) * multiple;\n}\n\n// TODO: Stop manually allocating CUDA memory; allocate an ATen byte\n// tensor instead.\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() {\n    if (data) {\n      THCudaFree(at::globalContext().lazyInitCUDA(), data);\n    }\n  }\n\n  size_t size;\n  void* data;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void * my_data,\n                       void * pair_data,\n                       void * pair_data2,\n  
                     void * pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void 
*> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             
y.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwdInference(stream, fuse_relu);\n\n  return y;\n\n}\n\nstd::vector<at::Tensor> nhwc_bn_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                   
    const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void * my_data,\n                       void * pair_data, \n                       void * pair_data2, \n                       void * pair_data3, \n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNorm *bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             x_grad.DATA_PTR<at::Half>(),\n                             nullptr,\n                             dy.DATA_PTR<at::Half>());\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll 
create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_fwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n\n    //max occupancy supported by the code is 2\n    return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_bwd_occupancy() {\n    int device_id=-1;\n    cudaGetDevice(&device_id);\n    \n    //max occupancy supported by the code is 2\n    return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);\n}\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/batch_norm.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <vector>\n#include <string>\n\n#include \"nhwc_batch_norm_kernel.h\"\n#include \"cuda_utils.h\"\n\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNorm {\n public:\n  NhwcBatchNorm() {\n    name_ = \"nhwc_batchnorm\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNorm() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnorm not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void dgrad(cudaStream_t stream, bool 
use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream, bool use_relu);\n  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format,\n                                  const cudnnDataType_t     data_type,\n                                  int n, int c, int h, int w, int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format,\n                                   const cudnnDataType_t     data_type,\n                                   int n, int c, int h, int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(\n      const std::vector<void*>&  workspace,\n      const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {\n    X_ = X;\n    dX_  = dX;\n    Y_   = Y;\n    dY_  = dY;\n  }\n\n  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers,\n                                 const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size()  == 2);\n    scale_  = static_cast<float*>(weight_pointers[0]);\n    bias_   = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_  = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_     = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status,\n                          const std::string& string = std::string(),\n                          bool verbose = VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << 
\" \" << cudnnGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  void checkCudaStatus(const std::string& string = std::string(),\n                       bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y)\n      LOG(INFO) << \"GPU capabilities exceeds assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t  X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t  Y_tensor_desc_ = nullptr;\n\n  void*  X_ = nullptr;\n  void* dX_ = nullptr;\n  void*  Y_ = nullptr;\n  void* dY_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_  = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_   = nullptr;\n  float* dbias_  = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_     = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_     = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors 
to get running variance\n\n  double exp_avg_factor_ = 0.;\n  double eps_            = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,\n                           cudnnTensorFormat_t format,\n                           cudnnDataType_t     data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float *partial_sums_ = nullptr;\n  int *partial_counts_ = nullptr;\n  int *retired_ctas_   = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams *params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams *params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  //typedef float StorageType;\n  // increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*2*sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,\n                                dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {\n\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto fwd_func = nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" fwd ser coop kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,\n                                dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {\n#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto bwd_func = nhwc_batch_norm_bwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                       
 THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" bwd coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_FUNC>(bwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd coop serial kernel\"); \\\n    } while (0)\n\n#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\"; \\\n        auto bwd_relu_func = 
nhwc_batch_norm_bwd_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd-relu coop serial kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze 
register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_KERNEL(1, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_KERNEL\n  }\n\n public:\n\n  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  
int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes     = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL*\\\n      ELEMENTS_PER_LDG*2*sizeof(float);\n  const size_t size_counts        = grid_y*grid_x*sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes,\n          size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid NhwcBatchNorm::setWorkspacePointers(\n      const std::vector<void*>& workspace,\n      const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 5);\n  assert(num_workspace_bytes.size() == 5);\n\n  minibatch_mean_     = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  retired_ctas_       = static_cast<int*>(workspace[2]);\n  partial_sums_       = static_cast<float*>(workspace[3]);\n  partial_counts_     = static_cast<int*>(workspace[4]);\n}\n\nvoid NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dst          = static_cast<uint16_t*>(Y_);\n  params->gmem_src1         = nullptr;\n  params->gmem_bias         = bias_;\n  params->gmem_scale        = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var  = population_variance_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->gmem_relu_bitmask = nullptr;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->rvar_inv_count    = rvar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_counts       = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps           = eps_;\n  params->outer_loops       = 0;\n  params->exp_avg_factor    = 
static_cast<float>(exp_avg_factor_);\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams\n                                                        *params) const {\n  params->gmem_src   = static_cast<uint16_t*>(X_);\n  params->gmem_dst   = static_cast<uint16_t*>(Y_);\n  params->gmem_src1  = nullptr;\n  params->gmem_bias  = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean  = population_mean_;\n  params->gmem_var   = population_variance_;\n  params->nhw        = m_;\n  params->c          = c_;\n  params->var_eps    = eps_;\n}\n\nvoid NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dy           = static_cast<uint16_t*>(dY_);\n  params->gmem_dst          = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1         = nullptr;\n  params->gmem_relu_bitmask = nullptr;\n  params->gmem_dscale       = dscale_;\n  params->gmem_dbias        = dbias_;\n  params->gmem_scale        = scale_;\n  params->gmem_bias         = bias_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops       = 0;\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      //      && minibatch_mean_ != nullptr\n      //      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ 
!= nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  if (use_relu) {\n    nhwc_batch_norm_fwd_inference\n      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>\n    <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n  } else {\n    nhwc_batch_norm_fwd_inference\n      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>\n    <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference kernel\");\n  }\n}\n\ndim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int 
c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                        const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\nvoid NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, \n                          const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && (bias_ != nullptr || !use_relu)\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      //      && population_mean_ != nullptr\n      //      && population_variance_ != nullptr\n      && X_ != nullptr\n      && dX_ != nullptr\n      //      && Y_ != nullptr\n      && dY_ != nullptr\n      && dscale_ != nullptr\n      && dbias_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include \"batch_norm_add_relu.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n//FIXME move the common stuff to common h file\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) {\n  return ((x + multiple - 1) / multiple) * multiple;\n}\n\n// TODO: Stop manually allocating CUDA memory; allocate an ATen byte\n// tensor instead.\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() {\n    if (data) {\n      THCudaFree(at::globalContext().lazyInitCUDA(), data);\n    }\n  }\n\n  size_t size;\n  void* data;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_addrelu_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n            
           void * my_data,\n                       void * pair_data,\n                       void * pair_data2,\n                       void * pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    
workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n  workspace.push_back(bitmask.DATA_PTR<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon) {\n\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, 
bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n                             nullptr,\n                             y.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z.DATA_PTR<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  
bn->fwdInference(stream);\n\n  return y;\n\n}\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       void * my_data,\n                       void * pair_data, \n                       void * pair_data2, \n                       void * pair_data3, \n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.DATA_PTR<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, z_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  z_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),\n            
                 x_grad.DATA_PTR<at::Half>(),\n                             nullptr,\n                             dy.DATA_PTR<at::Half>(),\n                             nullptr,\n                             z_grad.DATA_PTR<at::Half>());\n\n  bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});\n  bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void *> workspace;\n  workspace.push_back(minibatch_mean.DATA_PTR<float>());\n  workspace.push_back(minibatch_inv_var.DATA_PTR<float>());\n  workspace.push_back(bitmask.DATA_PTR<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();\n  assert(ret_cta.size(0)>=retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return 
std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_addrelu_fwd_occupancy() {\n    int device_id = -1;\n    cudaGetDevice(&device_id);\n\n    // max occupancy supported by the code is 2\n    return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_addrelu_bwd_occupancy() {\n    int device_id = -1;\n    cudaGetDevice(&device_id);\n\n    // max occupancy supported by the code is 2\n    return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/batch_norm_add_relu.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_add_relu.h\n * \\brief CUDA NHWC Batch Normalization code with fused addition\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <vector>\n#include <string>\n\n#include \"nhwc_batch_norm_kernel.h\"\n#include \"cuda_utils.h\"\n\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNormAddRelu {\n public:\n  NhwcBatchNormAddRelu() {\n    name_ = \"nhwc_batchnormaddrelu\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNormAddRelu() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnormaddrelu not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, 
const int grid_dim_x, const bool coop);\n  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream);\n  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format,\n                                  const cudnnDataType_t     data_type,\n                                  int n, int c, int h, int w, int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format,\n                                   const cudnnDataType_t     data_type,\n                                   int n, int c, int h, int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(\n      const std::vector<void*>&  workspace,\n      const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {\n    X_ = X;\n    dX_  = dX;\n    Y_   = Y;\n    dY_  = dY;\n    addend_   = addend;\n    dAddend_  = dAddend;\n  }\n\n  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers,\n                                 const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size()  == 2);\n    scale_  = static_cast<float*>(weight_pointers[0]);\n    bias_   = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_  = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_     = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status,\n                          const std::string& string = std::string(),\n                          bool verbose = 
VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << \" \" << cudnnGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  void checkCudaStatus(const std::string& string = std::string(),\n                       bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y)\n      LOG(INFO) << \"GPU capabilities exceed assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t  X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t  Y_tensor_desc_ = nullptr;\n\n  void*  X_ = nullptr;\n  void* dX_ = nullptr;\n  void*  Y_ = nullptr;\n  void* dY_ = nullptr;\n  void*  addend_ = nullptr;\n  void* dAddend_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_  = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_   = nullptr;\n  float* dbias_  = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_     = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_     = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float 
svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance\n\n  double exp_avg_factor_ = 0.;\n  double eps_            = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,\n                           cudnnTensorFormat_t format,\n                           cudnnDataType_t     data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float *partial_sums_ = nullptr;\n  int *partial_counts_ = nullptr;\n  int *retired_ctas_   = nullptr;\n  unsigned int *relu_bitmask_ = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams *params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams *params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  // 
increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \\\n      PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\\\n      ELEMENTS_PER_LDG*2*sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \\\n      PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,\n                                dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \\\n            \"Nhwc batchnormaddrelu kernel smem too big.\"; \\\n        auto fwd_func = nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if (COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_FWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_FWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        USE_RELU, \\\n                        USE_ADD_RELU, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<FWD_FUNC>(fwd_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_FWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" fwd ser coop kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,\n                                dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {\n#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \\\n    do { \\\n        CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \\\n            \"Nhwc batchnormaddrelu kernel smem too big.\"; \\\n        auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>; \\\n        if 
(COMPILED_FOR_OCCUPANCY > 1) { \\\n            cudaFuncSetAttribute(bwd_add_relu_func, \\\n                             cudaFuncAttributePreferredSharedMemoryCarveout, 100); \\\n            checkCudaStatus(name_ + \\\n                \" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)\"); \\\n        } \\\n        void *params_ptr = static_cast<void*>(&params); \\\n        using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \\\n                        StorageType, \\\n                        THREADS_PER_CTA, \\\n                        THREADS_PER_PIXEL, \\\n                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, \\\n                        PIXELS_PER_THREAD_IN_SMEM_BWD, \\\n                        ELEMENTS_PER_LDG, \\\n                        USE_ONLINE_APPROACH, \\\n                        OUTER_LOOPS, \\\n                        COMPILED_FOR_OCCUPANCY>); \\\n        if (COOP) { \\\n            cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } else { \\\n            cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \\\n                grid_dim, \\\n                THREADS_PER_CTA, \\\n                &params_ptr, \\\n                SMEM_SIZE_BWD, \\\n                stream); \\\n        } \\\n        checkCudaStatus(name_ + \" bwd-add-relu coop serial kernel\"); \\\n    } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_ADD_RELU_KERNEL\n  }\n\n public:\n  
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);\n    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes     = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n\n  int elems_per_group = ((m_ + 31) & ~31) * 2;\n  int group_count = div_up(c_, C_ELEMENTS_PER_CTA);\n  const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);\n\n  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL*\\\n      ELEMENTS_PER_LDG*2*sizeof(float);\n  const size_t size_counts        = grid_y*grid_x*sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes, bitmask_bytes,\n          size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid 
NhwcBatchNormAddRelu::setWorkspacePointers(\n      const std::vector<void*>& workspace,\n      const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 6);\n  assert(num_workspace_bytes.size() == 6);\n\n  minibatch_mean_     = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  relu_bitmask_       = static_cast<unsigned int*>(workspace[2]);\n  retired_ctas_       = static_cast<int*>(workspace[3]);\n  partial_sums_       = static_cast<float*>(workspace[4]);\n  partial_counts_     = static_cast<int*>(workspace[5]);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dst          = static_cast<uint16_t*>(Y_);\n  params->gmem_src1         = static_cast<uint16_t*>(addend_);\n  params->gmem_bias         = bias_;\n  params->gmem_scale        = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var  = population_variance_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->rvar_inv_count    = rvar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_counts       = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps           = eps_;\n  params->outer_loops       = 0;\n  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams\n                                                        *params) const {\n  params->gmem_src   = static_cast<uint16_t*>(X_);\n  params->gmem_dst   = static_cast<uint16_t*>(Y_);\n  params->gmem_src1  = 
static_cast<uint16_t*>(addend_);\n  params->gmem_bias  = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean  = population_mean_;\n  params->gmem_var   = population_variance_;\n  params->nhw        = m_;\n  params->c          = c_;\n  params->var_eps    = eps_;\n}\n\nvoid NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {\n  params->gmem_src          = static_cast<uint16_t*>(X_);\n  params->gmem_dy           = static_cast<uint16_t*>(dY_);\n  params->gmem_dst          = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1         = static_cast<uint16_t*>(dAddend_);\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->gmem_dscale       = dscale_;\n  params->gmem_dbias        = dbias_;\n  params->gmem_scale        = scale_;\n  params->gmem_bias         = bias_;\n  params->gmem_saved_mean   = minibatch_mean_;\n  params->gmem_saved_var    = minibatch_variance_;\n  params->nhw               = m_;\n  params->c                 = c_;\n  params->svar_inv_count    = svar_inv_count_;\n  params->gmem_sums         = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops       = 0;\n  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      //      && minibatch_mean_ != nullptr\n      //      && minibatch_variance_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      && addend_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, 
PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  nhwc_batch_norm_fwd_inference\n    <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>\n  <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n  checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n}\n\ndim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;\n    int 
pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                               const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && relu_bitmask_ != nullptr\n      && population_mean_ != nullptr\n      && population_variance_ != nullptr\n      && X_ != nullptr\n      //      && dX_ != nullptr\n      && Y_ != nullptr\n      && addend_ != nullptr\n      //      && dY_ != nullptr\n      //      && dscale_ != nullptr\n      //      && dbias_ != nullptr\n      && partial_sums_   != nullptr\n      && partial_counts_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\nvoid NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                 const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set =\n      X_tensor_desc_ != nullptr\n      && Y_tensor_desc_ != nullptr\n      && scale_ != 
nullptr\n      && bias_ != nullptr\n      && minibatch_mean_ != nullptr\n      && minibatch_variance_ != nullptr\n      && relu_bitmask_ != nullptr\n      //      && population_mean_ != nullptr\n      //      && population_variance_ != nullptr\n      && X_ != nullptr\n      && dX_ != nullptr\n      //      && Y_ != nullptr\n      && dY_ != nullptr\n      && dAddend_ != nullptr\n      && dscale_ != nullptr\n      && dbias_ != nullptr\n      && retired_ctas_   != nullptr;\n\n  if (!ptrs_are_set)\n    die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/cuda_utils.h",
    "content": "#ifndef CUDA_UTILS_H\n#define CUDA_UTILS_H\n\n#include <ATen/cuda/CUDAContext.h>\n\nnamespace at {\nnamespace cuda {\nnamespace utils {\n\nstatic inline int MaxSharedMemoryPerMultiprocessor(int device_id) {\n    return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;\n}\n\n}  // namespace utils\n}  // namespace cuda\n}  // namespace at\n\n#endif  // CUDA_UTILS_H\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/interface.cpp",
    "content": "#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/ArrayRef.h>\n#include <ATen/ScalarType.h>\n#include \"ATen/Scalar.h\"\n#ifndef VERSION_GE_1_1\n#include \"ATen/Type.h\"\n#endif\n#include \"ATen/Tensor.h\"\n#include \"ATen/Storage.h\"\n#include \"ATen/Generator.h\"\n\n\nnamespace py = pybind11;\n\nint64_t get_buffer_size(\n                       const int bn_sync_steps);\n\nvoid* get_data_ptr(\n                       const at::Tensor& data);\n\nvoid* get_remote_data_ptr(\n                       const at::Tensor& handle,\n                       const int64_t offset);\n\nvoid close_remote_data(\n                       const at::Tensor& handle);\n\nat::Tensor nhwc_bn_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& 
running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu);\n\nstd::vector<at::Tensor> nhwc_bn_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       const bool fuse_relu,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_train(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float 
epsilon,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(\n                       const at::Tensor& x,\n                       const at::Tensor& z,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& ret_cta,\n                       const int bn_group,\n                       const float momentum,\n                       const float epsilon);\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(\n                       const at::Tensor& x,\n                       const at::Tensor& dy,\n                       const at::Tensor& scale,\n                       const at::Tensor& bias,\n                       const at::Tensor& running_mean,\n                       const at::Tensor& running_inv_var,\n                       const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var,\n                       const at::Tensor& bitmask,\n                       const at::Tensor& ret_cta,\n                       const float momentum,\n                       const float epsilon,\n                       void* my_data,\n                       void* pair_data,\n                       void* pair_data2,\n                       void* pair_data3,\n                       const int bn_group,\n                       const at::Tensor& magic_tensor,\n                       const int occupancy,\n                       const int grid_dim_x,\n                       const bool 
coop);\n\nint nhwc_bn_fwd_occupancy();\nint nhwc_bn_bwd_occupancy();\n\nint nhwc_bn_addrelu_fwd_occupancy();\nint nhwc_bn_addrelu_bwd_occupancy();\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n\n  m.def(\"get_buffer_size\", &get_buffer_size, \"get_buffer_size\");\n  m.def(\"get_data_ptr\", &get_data_ptr, \"get_data_ptr\");\n  m.def(\"get_remote_data_ptr\", &get_remote_data_ptr, \"get_remote_data_ptr\");\n  m.def(\"close_remote_data\", &close_remote_data, \"close_remote_data\");\n\n  m.def(\"bn_fwd_nhwc\", &nhwc_bn_fwd_train, \"bn_fwd_nhwc\");\n  m.def(\"bn_fwd_eval_nhwc\", &nhwc_bn_fwd_eval, \"bn_fwd_eval_nhwc\");\n  m.def(\"bn_bwd_nhwc\", &nhwc_bn_bwd, \"bn_bwd_nhwc\");\n\n  m.def(\"bn_fwd_nhwc_occupancy\", &nhwc_bn_fwd_occupancy, \"bn_fwd_nhwc_occupancy\");\n  m.def(\"bn_bwd_nhwc_occupancy\", &nhwc_bn_bwd_occupancy, \"bn_bwd_nhwc_occupancy\");\n\n  m.def(\"bn_addrelu_fwd_nhwc\", &nhwc_bn_addrelu_fwd_train, \"bn_addrelu_fwd_nhwc\");\n  m.def(\"bn_addrelu_fwd_eval_nhwc\", &nhwc_bn_addrelu_fwd_eval, \"bn_addrelu_fwd_eval_nhwc\");\n  m.def(\"bn_addrelu_bwd_nhwc\", &nhwc_bn_addrelu_bwd, \"bn_addrelu_bwd_nhwc\");\n\n  m.def(\"bn_addrelu_fwd_nhwc_occupancy\", &nhwc_bn_addrelu_fwd_occupancy, \"bn_addrelu_fwd_nhwc_occupancy\");\n  m.def(\"bn_addrelu_bwd_nhwc_occupancy\", &nhwc_bn_addrelu_bwd_occupancy, \"bn_addrelu_bwd_nhwc_occupancy\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/ipc.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCNumerics.cuh>\n\n#include \"THC/THC.h\"\n\n#include <cuda.h>\n\n#include \"compat.h\"\n\n\n#define cudaCheckErrors(msg) \\\n    do { \\\n        cudaError_t __err = cudaGetLastError(); \\\n        if (__err != cudaSuccess) { \\\n            fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", \\\n                msg, cudaGetErrorString(__err), \\\n                __FILE__, __LINE__); \\\n            fprintf(stderr, \"*** FAILED - ABORTING\\n\"); \\\n            exit(1); \\\n        } \\\n    } while (0)\n\ntemplate<>\nstruct std::hash<cudaIpcMemHandle_t> {\n  size_t operator() (const cudaIpcMemHandle_t& handle) const {\n    size_t hash = 0;\n    uint8_t* ptr = (uint8_t*)&handle;\n    assert(sizeof(uint8_t) == 1);\n    for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {\n      hash += *ptr;\n      ptr++;\n    }\n    return hash;\n  }\n};\n\ntemplate<>\nstruct std::equal_to<cudaIpcMemHandle_t> {\n  bool operator() (const cudaIpcMemHandle_t &lhs,\n                             const cudaIpcMemHandle_t &rhs) const {\n    return (std::memcmp((void*) &lhs,\n                        (void*) &rhs,\n                        sizeof(cudaIpcMemHandle_t)) == 0);\n  }\n};\n\nnamespace {\n\nnamespace gpuipc {\n//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h\n// The number of threads per pixel.\nconst int THREADS_PER_PIXEL = 16;\n// The number of elements per ldg.\nconst int ELEMENTS_PER_LDG = 4;\n// The number of reducing ops, each uses its own space : mean, var, dscale, dbias\nconst int REDUCE_OPS = 4;\n// Maximum block.y supported - limited due to buffer allocation\nconst int MAX_BLOCK_Y = 256;\nconst int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;\nconst int BYTES_PER_ELEM = 4;\n// Buffer size per sync step\nconst int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;\n};\n\nclass IpcMemHandleRegistry {\npublic:\n  void* getPtr(const 
cudaIpcMemHandle_t& handle, int64_t offset) {\n    if (registry_.count(handle) == 0) {\n      registry_.insert(std::make_pair(handle, RegistryEntry()));\n      registry_[handle].dev_ptr = ipcOpenMem(handle);\n    }\n    registry_[handle].ref_count++;\n    return (((uint8_t*)registry_[handle].dev_ptr) + offset);\n  }\n\n  void releasePtr(const cudaIpcMemHandle_t& handle) {\n    if (registry_.count(handle) == 0) {\n      // Unknown handle: nothing was opened for it, so there is nothing to release.\n      return;\n    }\n    if (--registry_[handle].ref_count == 0) {\n      ipcCloseMem(registry_[handle].dev_ptr);\n      registry_.erase(handle);\n    }\n  }\n\n  struct RegistryEntry {\n    void* dev_ptr;\n    int   ref_count;\n    RegistryEntry() : dev_ptr(NULL) , ref_count(0) {}\n  };\n\nprotected:\n  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;\n\n  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {\n    void *data;\n    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);\n    cudaCheckErrors(\"ipc init\");\n    return data;\n  }\n\n  void ipcCloseMem(void* dev_ptr) {\n    cudaIpcCloseMemHandle(dev_ptr);\n    cudaCheckErrors(\"ipc close\");\n  }\n\n};\n\n}\n\nstatic IpcMemHandleRegistry ipc_mem_registry;\n\nint64_t get_buffer_size(const int bn_sync_steps) {\n  return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;\n}\n\nvoid* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {\n  cudaIpcMemHandle_t my_handle;\n  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));\n  return ipc_mem_registry.getPtr(my_handle, offset);\n}\n\nvoid close_remote_data(const at::Tensor& handle) {\n  cudaIpcMemHandle_t my_handle;\n  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));\n  ipc_mem_registry.releasePtr(my_handle);\n}\n\nvoid* get_data_ptr(\n                   const at::Tensor& data) {\n  return data.DATA_PTR<uint8_t>();\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_kernel.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n*/\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n\n#include <stdint.h>\n#include <algorithm>\n\n#define DEVICE_FUNCTION static inline __device__\n\n// CTA margin used by cooperative launch. 
Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN     3\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< typename T, int ELEMENTS_PER_LDG >\nstruct PackedStorage {\n    enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };\n    typedef T Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int ELEMENTS_PER_LDG >\nstruct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {\n    enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };\n    typedef int Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        uint16_t lo, hi;\n        asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(lo) : \"f\"(src[2*i+0]));\n        asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(hi) : \"f\"(src[2*i+1]));\n        asm volatile(\"mov.b32 %0, {%1, %2};\"  : \"=r\"(dst[i]) : \"h\"(lo), \"h\"(hi));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = src[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        uint16_t lo, hi;\n        asm volatile(\"mov.b32 {%0, %1}, %2;\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(src[i]));\n        asm volatile(\"cvt.f32.f16 %0, %1;\"   : \"=f\"(dst[2*i+0])   
: \"h\"(lo));\n        asm volatile(\"cvt.f32.f16 %0, %1;\"   : \"=f\"(dst[2*i+1])   : \"h\"(hi));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = src[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {\n    dst[0] = __ldg((const int*) gmem);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {\n    unsigned int tmp;\n    asm volatile (\"ld.global.cs.nc.s32 %0, [%1];\"  : \"=r\"(tmp) : \"l\" ((const uint *)gmem));\n    dst[0] = tmp;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {\n    int2 tmp = __ldg((const int2*) gmem);\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {\n    int2 tmp;\n    asm volatile (\"ld.global.cs.nc.v2.s32 {%0,%1}, [%2];\"\n        : \"=r\"(tmp.x), \"=r\"(tmp.y) : \"l\"((const int2 *)gmem));\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {\n    int tmp[N/2];\n    ldg(tmp, gmem);\n    to_float(dst, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void ldg_stream(float (&dst)[N], const 
uint16_t *gmem) {\n    int tmp[N/2];\n    ldg_stream(tmp, gmem);\n    to_float(dst, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {\n    reinterpret_cast<int*>(gmem)[0] = src[0];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {\n    unsigned int tmp = src[0];\n    asm volatile (\"st.global.cs.s32 [%0], %1;\"\n        :: \"l\"((uint *)gmem) , \"r\"(tmp));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {\n    reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {\n    asm volatile (\"st.global.cs.v2.s32 [%0], {%1,%2};\"\n        :: \"l\"((uint *)gmem) , \"r\"(src[0]), \"r\"( src[1]));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {\n    int tmp[N/2];\n    from_float(tmp, src);\n    stg(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {\n    int tmp[N/2];\n    from_float(tmp, src);\n    stg_stream(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {\n    float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));\n    dst[0] = tmp.x;\n    dst[1] = 
tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {\n    float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));\n    dst[0] = tmp.x;\n    dst[1] = tmp.y;\n    dst[2] = tmp.z;\n    dst[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {\n    float2 tmp = *(const float2*) &smem[2*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {\n    x[0] = smem[idx];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {\n    float4 tmp = *(const float4*) &smem[4*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n    x[2] = tmp.z;\n    x[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {\n    int2 tmp = *(const int2*) &smem[2*idx];\n    x[0] = tmp.x;\n    x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {\n    reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {\n    reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], 
src[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {\n    reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {\n    reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {\n    smem[idx] = x[0];\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {\n    reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {\n    reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void zero_array(int (&dst)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = 0;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int N >\nDEVICE_FUNCTION void zero_array(float (&dst)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        dst[i] = 0.f;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate 
<int N>\nDEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] += y[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] *= y[i];\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] *= scalar;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],\n                               const float (&scale)[N], const float (&m1)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] = bias[i] + scale[i] * (x[i] - m1[i]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage>\nDEVICE_FUNCTION Storage relu(Storage in) {\n    Storage zero = (Storage)0.f;\n    return (in < zero)? 
zero : in;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_activation(float (&x)[N]) {\n    #pragma unroll\n    for (int i = 0; i < N; ++i) {\n        x[i] = relu(x[i]);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\ntemplate< int THREADS_PER_CTA >\nDEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,\n                                        void* params_my_data, void** params_pair_datas, int off,\n                                        const int magic,\n                                        const int sync_iters) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of threads per pixel.\n    const int THREADS_PER_PIXEL = 16;\n    // The number of elements per ldg.\n    const int ELEMENTS_PER_LDG = 4;\n    // The number of reducing ops, each uses its own space : mean, var, dscale, dbias\n    const int REDUCE_OPS = 4;\n    // Maximum block.y supported - limited due to buffer allocation\n    const int MAX_BLOCK_Y = 256;\n    const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;\n    // The warp decomposition.\n    const int warp_id = threadIdx.x / THREADS_PER_WARP;\n    const int lane_id = threadIdx.x % THREADS_PER_WARP;\n    // total size of data per sync iter\n    const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;\n\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n    }\n\n    // The warp leaders, write to SMEM.\n    if (lane_id < THREADS_PER_PIXEL) {\n        write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);\n    }\n\n    // The data is in SMEM. 
Do the final reduction.\n    __syncthreads();\n\n    // The 1st warp does all the work.\n    // We do the final reduction each half-warp sequentially reduces the final values.\n    if (warp_id == 0) {\n        read_from_smem(x, smem, threadIdx.x);\n\n        #pragma unroll\n        for (int offset = 1;\n             offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n            float y[ELEMENTS_PER_LDG];\n            // Read the mean and variance from the other pixel.\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);\n            // Compute the updated sum.\n            add(x, y);\n        }\n\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n        }\n\n        // Make sure the data was read from SMEM.\n        __syncwarp();\n\n        // Store the final values.\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n        // probably could do it earlier, before sync\n\n        for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {\n            //float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];\n            void* params_pair_data = params_pair_datas[sync_iter];\n\n            // skip the space consumed by previous sync iterations\n            const int xbuf_offset = sync_iter*data_total;\n            // data starts after flags, but have to skip previous\n            const int data_offset = xbuf_offset\n                                    + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2\n                                    + ELEMENTS_PER_LDG*threadIdx.x*2;\n\n            // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU\n            if (blockIdx.x == 0) {\n                volatile float * write_data =\n                    &((reinterpret_cast<float*>(params_pair_data))[data_offset]);\n\n                // write the data to memory region to be reflected to other GPU\n    
            asm volatile (\"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};\"\n                    :: \"l\"(write_data) , \"f\"(x[0]), \"r\"(magic), \"f\"(x[2]), \"r\"(magic));\n\n                asm volatile (\"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};\"\n                    :: \"l\"(write_data+4) , \"f\"(x[1]), \"r\"(magic), \"f\"(x[3]), \"r\"(magic));\n            }\n\n            // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU\n            volatile float * read_data =\n                &((reinterpret_cast<float*>(params_my_data))[data_offset]);\n\n            float other[4];\n            uint32_t other_flag_a, other_flag_b;\n            do {\n                asm volatile (\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                    : \"=f\"(other[0]), \"=r\"(other_flag_a), \"=f\"(other[2]), \"=r\"(other_flag_b) : \"l\"(read_data));\n            } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n            do {\n                asm volatile (\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                    : \"=f\"(other[1]), \"=r\"(other_flag_a), \"=f\"(other[3]), \"=r\"(other_flag_b) : \"l\"(read_data+4));\n            } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n            add(x, other);\n        }\n        // finally, after syncing up and accounting for partial sums from\n        // other GPUs as required, write the result\n\n\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_CTA >\nDEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of threads per pixel.\n    const int THREADS_PER_PIXEL = 8;\n    // The number of 
elements per ldg.\n    const int ELEMENTS_PER_LDG = 4;\n    // The warp decomposition.\n    const int warp_id = threadIdx.x / THREADS_PER_WARP;\n    const int lane_id = threadIdx.x % THREADS_PER_WARP;\n\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n        x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);\n    }\n\n    // The warp leaders, write to SMEM.\n    if (lane_id < THREADS_PER_PIXEL) {\n        write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);\n    }\n\n    // The data is in SMEM. Do the final reduction.\n    __syncthreads();\n\n    // The 1st warp does all the work.\n    // We do the final reduction each half-warp sequentially reduces the final values.\n    if (warp_id == 0) {\n        read_from_smem(x, smem, threadIdx.x);\n\n        #pragma unroll\n        for (int offset = 1;\n             offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n            float y[ELEMENTS_PER_LDG];\n            // Read the mean and variance from the other pixel.\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);\n            // Compute the updated sum.\n            add(x, y);\n        }\n\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);\n            x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);\n        }\n\n        // Make sure the data was read from SMEM.\n        __syncwarp();\n\n        // Store the final values.\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >\nDEVICE_FUNCTION void parallel_sums(float *smem, float 
(&x)[ELEMENTS_PER_LDG], int nhw) {\n    // The size of a warp.\n    const int THREADS_PER_WARP = 32;\n    // The number of warps in a CTA.\n    const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n    // The number of pixels computed by a single warp.\n    const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;\n\n    // The position in the warp.\n    const int nhw_in_warp = nhw % PIXELS_PER_WARP;\n    // The C in the warp.\n    const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;\n\n    // Store the values to shared memory.\n    write_to_smem(smem, threadIdx.x, x);\n\n    // Compute the parallel sums.\n    for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {\n        // NOP.\n        __syncwarp();\n\n        // Read the running sum from the other thread.\n        float y[ELEMENTS_PER_LDG];\n        if (nhw_in_warp < offset) {\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);\n        }\n\n        // Compute the updated sum.\n        add(x, y);\n\n        // NOP.\n        __syncwarp();\n\n        // Update the sum in SMEM.\n        if (offset > 1 && nhw_in_warp < offset) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n\n    // The warps are done. Do the final reduction at the CTA level.\n    __syncthreads();\n\n    // The warp leaders, write to SMEM.\n    const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;\n    if (nhw_in_warp == 0) {\n        write_to_smem(smem, idx, x);\n    }\n\n    // The data is in SMEM. Do the final reduction.\n    __syncthreads();\n\n    // Read the 1st element to prepare the work.\n    if (nhw < WARPS_PER_CTA/2) {\n        read_from_smem(x, smem, threadIdx.x);\n    }\n\n    // We have the running mean and running m2. 
Let's build the mean/var of the CTA.\n    for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {\n        // NOP.\n        __syncwarp();\n\n        // Read the mean and variance from the other pixel.\n        float y[ELEMENTS_PER_LDG];\n        if (nhw < offset) {\n            read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);\n        }\n\n        // Compute the updated sum.\n        add(x, y);\n\n        // NOP.\n        __syncwarp();\n\n        // Store the mean/var for the different pixels.\n        if (nhw < offset) {\n            write_to_smem(smem, threadIdx.x, x);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >\nstruct ParallelSums {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {\n        parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct ParallelSums<16, 4> {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {\n        parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);\n    }\n\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {\n        parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);\n    }\n};\n\ntemplate<>\nstruct ParallelSums<8, 4> {\n    template< int THREADS_PER_CTA >\n    DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {\n        parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline int div_up(int m, int n) {\n    return (m + n - 1) / n;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// It is expected that all threads in the CTA enter this function!\nDEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {\n\n    // Register the CTA.\n    if (threadIdx.x == 0) {\n        // Issue the membar.\n        __threadfence();\n        // Notify that the CTA is done.\n        int val_to_add = 1;\n        if (master) {\n            val_to_add = -(expected_count - 1);\n        }\n        atomicAdd(gmem_retired_ctas, val_to_add);\n    }\n\n    // Are all CTAs done?\n    if (threadIdx.x == 0) {\n        int retired_ctas = -1;\n        do {\n            __threadfence();\n            asm volatile (\"ld.global.cg.b32 %0, [%1];\"\n                : \"=r\"(retired_ctas) : \"l\"(gmem_retired_ctas));\n        } while (retired_ctas != 0);\n    }\n    __syncthreads();\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdInferenceParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n    // the final mean and variance as calculated during the training process\n    float *gmem_mean, *gmem_var;\n    // The bias/scale.\n    float *gmem_bias, *gmem_scale;\n    // The dimensions.\n    int nhw, c;\n    // epsilon\n    float var_eps;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int 
ELEMENTS_PER_LDG,\n    bool USE_RELU,\n    bool USE_ADD_RELU\n>\n__global__ __launch_bounds__(THREADS_PER_CTA)\n    void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // The start position in the NHW dimension where the CTA starts.\n    const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n    // thread's starting point in NHW\n    const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];\n    float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];\n    zero_array(mean);\n    zero_array(var);\n    zero_array(scale);\n    zero_array(bias);\n    if (is_valid_c) {\n        read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);\n        read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n        read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);\n        read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n    }\n\n    // Update the scale with the stddev and eps.\n    #pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        scale[i] *= rsqrtf(var[i] + params.var_eps);\n    }\n\n    
// The base pointers for reading/writing\n    uint16_t *const gmem_src = &params.gmem_src[thread_c];\n    uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n    const uint16_t *gmem_src1 = nullptr;\n    if (USE_ADD_RELU) {\n        gmem_src1 = &params.gmem_src1[thread_c];\n    }\n\n    // apply BN\n    for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {\n        float x_math[ELEMENTS_PER_LDG];\n        zero_array(x_math);\n        if (is_valid_c) {\n            ldg(x_math, &gmem_src[nhw*params.c]);\n        }\n\n        // Normalize and apply activation function\n        normalize(x_math, bias, scale, mean);\n        if (USE_ADD_RELU) {\n            float x1_math[ELEMENTS_PER_LDG];\n            ldg(x1_math, &gmem_src1[nhw*params.c]);\n            add(x_math, x1_math);\n            relu_activation(x_math);\n        } else if (USE_RELU) {\n            relu_activation(x_math);\n        }\n\n        if (is_valid_c) {\n            stg(&gmem_dst[nhw*params.c], x_math);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n    // The bias/scale.\n    float *gmem_bias, *gmem_scale;\n    // running mean/var (refer BN API from cudnn doc)\n    float *gmem_running_mean, *gmem_running_var;\n    // saved mean/var (refer BN API from cudnn doc)\n    float *gmem_saved_mean, *gmem_saved_var;\n    // ReLU bitmask\n    unsigned int *gmem_relu_bitmask;\n    // The dimensions.\n    int nhw, c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    float svar_inv_count;\n    // factor to scale sum of squared errors to get running variance. 
Should be 1/nhw or 1/(nhw-1).\n    float rvar_inv_count;\n    // The buffer to do the reduction for mean, stddev and count.\n    float *gmem_sums;\n    // The buffer to count items in the different CTAs.\n    int *gmem_counts;\n    // The counters of retired CTAs.\n    int *gmem_retired_ctas;\n    // The epsilon to apply to the computation of the variance.\n    float var_eps;\n    // outer loop count\n    int outer_loops;\n    // exponential average factor\n    float exp_avg_factor;\n    // number of CTAs along .x dimension\n    int c_blks;\n\n    void* my_data;\n    void* pair_datas[4];\n    int magic;\n    int sync_iters;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    bool USE_RELU,\n    bool USE_ADD_RELU,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / 
THREADS_PER_PIXEL;\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        // Clamp thread_c so that we load from valid locations even if we don't use the value\n        if (!is_valid_c)\n            thread_c = params.c - 4;\n\n        // Single pass numerically stable algorithm, see:\n        // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm\n        //\n        // n = 0, mean = 0.0, M2 = 0.0\n        //\n        // for x in data:\n        //     n += 1\n        //     delta = 
x - mean\n        //     mean += delta/n\n        //     delta2 = x - mean\n        //     M2 += delta*delta2\n        //\n        // if n < 2:\n        //     return float('nan')\n        // else:\n        //     return M2 / (n - 1)\n\n        // Register to store the number of elements read so far.\n        float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            mean[i] = 0.f;\n            m2[i] = 0.f;\n        }\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointer to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. Compute the mean/var across those elements.\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized; the offset is evenly divisible by 32\n            int offset = (pixels_per_iteration * OUTER_LOOPS +\n                          PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;\n            cta_nhw_regs -= offset;\n            cta_nhw_smem -= offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!!\n            cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) -\n                                 max(nhw_regs, 0), 0);\n\n            // Load the data and compute the local mean/sum and the variance.\n            if (USE_ONLINE_APPROACH) {\n                // Read the elements from memory.\n                float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                    zero_array(x_storage[i]);\n                    is_valid[i] = 0.f;\n                    if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        }\n                        is_valid[i] = 1.f;\n                    }\n                }\n\n                // Do the math.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Update the count.\n                    count += is_valid[i];\n                    // Invert the count.\n                    float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        float delta0 = x_math[j] - mean[j];\n                        mean[j] += delta0 * inv_count;\n                        float delta1 = x_math[j] - mean[j];\n                        m2[j] += delta0 * delta1 * is_valid[i];\n                    }\n                }\n            } else {\n                // Read the elements from memory.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                    zero_array(x_storage[i]);\n                    if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        }\n                        count += 1.f;\n                    }\n                }\n\n                // Sum the elements in registers.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        mean[j] += x_math[j];\n                    }\n                }\n\n                // Compute the mean.\n                float inv_count = 1.f / count;\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    mean[j] 
*= inv_count;\n                }\n\n                // Compute the variance.\n                #pragma unroll\n                for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                    // Convert to float.\n                    float x_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage[i]);\n\n                    // Is it a valid pixel?\n                    float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;\n                    // Update the mean and m2 using deltas.\n                    #pragma unroll\n                    for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                        m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;\n                    }\n                }\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                float is_pixel_valid = (((unsigned int)idx <\n                                         (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;\n\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n                ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? 
idx : 0)*params.c]);\n\n                // The offset to store in SMEM.\n                const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                // Update the mean and m2 using deltas.\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    float delta0 = x_math[j] - mean[j];\n                    mean[j] += delta0 * inv_count;\n                    float delta1 = x_math[j] - mean[j];\n                    m2[j] += delta0 * delta1 * is_pixel_valid;\n                }\n            }\n        }\n\n        // We scale the mean by the number of elements. 
It brings more stability.\n        float m1[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = mean[i] * count;\n        }\n\n        // Run the parallel sum across the CTA to get the local sum.\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, m1, thread_in_cta_nhw);\n        __syncthreads();\n\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(m1, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // Adjust the variance.\n        float inv_cta_count = 1.f / static_cast<float>(cta_count);\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            float mean_diff = m1[i]*inv_cta_count - mean[i];\n            m2[i] = m2[i] + mean_diff * mean_diff * count;\n        }\n\n        // Run the parallel sum across the CTA to get the local adjusted variance.\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, m2, thread_in_cta_nhw);\n\n        // The workspace in global memory is distributed across the different CTAs.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, m1);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);\n        }\n\n        // The memory location to store the number of pixels per CTA.\n        int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];\n        if (threadIdx.x == 0) {\n            gmem_counts[blockIdx.x] = cta_count;\n        }\n\n        // Read the bias and scale.\n        float bias[ELEMENTS_PER_LDG], 
scale[ELEMENTS_PER_LDG];\n        if (is_valid_c) {\n            read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n            read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the mean to compute the global mean.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = 0.f;\n        }\n\n        // Build the global mean.\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp, gmem_sums, idx);\n            add(m1, tmp);\n        }\n\n        if (params.sync_iters>0)\n        {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, m1, thread_in_cta_nhw);\n        }\n        __syncthreads();\n\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(m1, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // Normalize the mean.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m1[i] = m1[i] * params.svar_inv_count;\n        }\n\n        // Reset the variance.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            m2[i] = 0.f;\n        }\n\n        // for add+relu fusion\n        
const uint16_t *gmem_src1 = nullptr;\n        if (USE_ADD_RELU) {\n            gmem_src1 = &params.gmem_src1[thread_c];\n        }\n\n        // Build the global variance.\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.\n            float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp_mean, &gmem_sums[                           0], idx);\n            read_from_gmem(tmp_var,  &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);\n\n            // Read the number of pixels visited by a given CTA.\n            cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);\n\n            // Compute the diff to update the variance.\n            float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;\n            }\n\n            // Update the variance.\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);\n            }\n        }\n\n        if (params.sync_iters>0)\n        {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, m2, thread_in_cta_nhw);\n        }\n        __syncthreads();\n\n        read_from_smem(m2, smem, thread_in_cta_c);\n\n        // Finalize the stddev.\n        // because saved var and running var may have different denominators, we don't do 
it here\n        // scale_(m2, inv_count);\n\n        // store the saved mean/var\n        float svarinv[ELEMENTS_PER_LDG];\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);\n        }\n        if (is_valid_for_saving) {\n            write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);\n            write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);\n        }\n\n        // store the running mean/var\n        float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];\n        zero_array(rmean);\n        zero_array(rvar);\n        if (params.exp_avg_factor != 1.f && is_valid_for_saving) {\n            read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);\n            read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] +   \\\n                params.exp_avg_factor * m1[i];\n            rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] +     \\\n                params.exp_avg_factor * (m2[i] * params.rvar_inv_count);\n        }\n        if (is_valid_for_saving) {\n            write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);\n            write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);\n        }\n\n        // Update the scale with the stddev and eps.\n        multiply(scale, svarinv);\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +\n                                     ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n        // Store the elements in 
registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_valid = is_valid_nhw && is_valid_c;\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n\n                // Normalize and apply activation function\n                normalize(x_math, bias, scale, m1);\n                if (USE_ADD_RELU) {\n                    float x1_math[ELEMENTS_PER_LDG];\n                    ldg_stream(x1_math, &gmem_src1[(is_valid ? 
idx : 0)*params.c]);\n                    add(x_math, x1_math);\n                    unsigned int relu_mask;\n                    int lane_id = threadIdx.x & 31;\n                    #pragma unroll\n                    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                        bool rectified = x_math[i] < 0.0F;\n                        unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n                        if (lane_id == i) {\n                            // Thread 0 remembers the relu_mask from the first time through this\n                            // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.\n                            relu_mask = local_relu_mask;\n                        }\n                        if (rectified) {\n                            x_math[i] = 0.0F;\n                        }\n                    }\n                    if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n                        gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n                    }\n                } else if (USE_RELU) {\n                    relu_activation(x_math);\n                }\n\n                // Write back.\n                if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], x_math);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            #pragma unroll 2\n            
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_valid = is_valid_nhw && is_valid_c;\n\n                // Read from SMEM.\n                const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n                read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                float x_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n\n                // Normalize and apply activation function\n                normalize(x_math, bias, scale, m1);\n                if (USE_ADD_RELU) {\n                    float x1_math[ELEMENTS_PER_LDG];\n                    ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);\n                    add(x_math, x1_math);\n                    unsigned int relu_mask;\n                    int lane_id = threadIdx.x & 31;\n                    #pragma unroll\n                    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                        bool rectified = x_math[i] < 0.0F;\n                        unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n                        if (lane_id == i) {\n                            relu_mask = local_relu_mask;\n                        }\n                        if (rectified) {\n                            x_math[i] = 0.0F;\n                        }\n                    }\n                    if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n                        gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n                    }\n                } else if (USE_RELU) {\n                    relu_activation(x_math);\n                }\n\n                // Write back.\n    
            if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], x_math);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormBwdParams {\n    // The input/output tensors.\n    uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;\n    // dscale/dbias\n    float *gmem_dscale, *gmem_dbias;\n    // The scale and bias.\n    float *gmem_scale, *gmem_bias;\n    // The mean/inv-var saved from fwd pass\n    float *gmem_saved_mean, *gmem_saved_var;\n    // ReLU bitmask\n    unsigned int *gmem_relu_bitmask;\n    // The dimensions.\n    int nhw, c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    float svar_inv_count;\n    // The buffer to do the reduction for dscale and dbias\n    float *gmem_sums;\n    // The counters of retired CTAs.\n    int *gmem_retired_ctas;\n    // outer loop count\n    int outer_loops;\n    // number of CTAs along .x dimension\n    int c_blks;\n\n    void* my_data;\n    void* pair_datas[4];\n    int magic;\n    int sync_iters;\n    float wgrad_coeff;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],\n                              const float (&mean_var_scale_bias)[N],\n                              const float (&var_scale)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n        if ((y <= 0.f) && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float 
(&y)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if ((y[j] <= 0.f) && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if (rectified[j] && valid_data) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N],\n                                     const float (&x)[N],\n                                     const float (&mean_var_scale_bias)[N],\n                                     const float (&var_scale)[N]) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n        if (y <= 0.f) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        if (y[j] <= 0.f) {\n            dy[j] = 0.f;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],\n                                const float (&dy)[N], const float (&x)[N],\n                                const float (&mean)[N], float inv_count) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float delta0 = dy[j] - dbias[j];\n        dbias[j] += delta0 * inv_count;\n        delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];\n        dscale[j] += delta0 * inv_count;\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],\n                            const float (&var)[N], const float (&x)[N], const float (&mean)[N],\n                            const float (&dscale)[N], const float (&dbias)[N], float inv_count) {\n    #pragma unroll\n    for (int j = 0; j < N; ++j) {\n        float tmp1 = dy[j] - (dbias[j]* inv_count);\n        float tmp2 = dscale[j] * inv_count;\n        float tmp3 = x[j] - mean[j];\n        dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    
typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        // Registers to store the mean used for entire duration\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        
zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. Compute the sum across them.\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized\n            int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -\n                         PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n            cta_nhw_regs += offset;\n            cta_nhw_smem += offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                    } else {\n                        ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                    }\n                    is_valid[i] = 1.f;\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                bool is_pixel_valid = (((unsigned int)idx <\n                                        (unsigned int)params.nhw) && is_valid_c);\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid) {\n                    ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 
1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // We scale the mean by the number of elements. It brings more stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTA.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        
__syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // inv-var\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // Normalize the dscale.\n        multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // scale\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n  
          for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float 
smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n\n        // Registers to store the mean/var/scale/bias used for the entire duration\n        // Register usage optimizations:\n        // 1. 
Can combine bias - (mean * var * scale) into a single register\n        // 2. Can combine var * scale into a single register\n        float varscale[ELEMENTS_PER_LDG];\n        zero_array(varscale);\n        if (is_valid_c) {\n            read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        float tmp[ELEMENTS_PER_LDG];\n        zero_array(tmp);\n        if (is_valid_c) {\n            read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(varscale, tmp);\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n        zero_array(tmp);\n        if (is_valid_c) {\n            read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);\n        }\n        float mean_var_scale_bias[ELEMENTS_PER_LDG];\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. 
Compute sum across them\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized\n            int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -\n                         PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n            cta_nhw_regs += offset;\n            cta_nhw_smem += offset;\n        }\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                    } else {\n                        ldg(x_storage[i], &gmem_src[idx*params.c]);\n                        ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                    }\n                    is_valid[i] = 1.f;\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // 
Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 1.f / count : 0.f;\n\n                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                bool is_pixel_valid = (((unsigned int)idx <\n                                        (unsigned int)params.nhw) && is_valid_c);\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid) {\n                    ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += 
PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n            }\n        }\n\n        // We scale the mean by the number of elements. It brings more stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTA.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n   
         const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n 
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // Normalize the dscale.\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = 
&params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n                relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + 
i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n                    relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk.  
Needed?\n        __syncthreads();\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    typename Storage,\n    int THREADS_PER_CTA,\n    int THREADS_PER_PIXEL,\n    int PIXELS_PER_THREAD_IN_REGISTERS,\n    int PIXELS_PER_THREAD_IN_SMEM,\n    int ELEMENTS_PER_LDG,\n    int USE_ONLINE_APPROACH,\n    int OUTER_LOOPS_,\n    int DESIRED_OCCUPANCY\n>\n__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)\n    void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {\n    // The number of pixels loaded in a single LDG.\n    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n    // The number of pixels computed per CTA stored in registers.\n    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n    // The number of pixels computed per CTA stored in SMEM.\n    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;\n    // The number of C elements per CTA.\n    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;\n\n    // Shared memory to do CTA-wide parallel sums.\n    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];\n\n    // The adapter for the storage.\n    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n    // The data type for packed storage in SMEM.\n    typedef typename PackedStorage_::Type PackedStorageType;\n    // The number of elements in the packed storage.\n    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n    // Registers to keep the data live for the persistent approach.\n    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n    // Shared memory buffer to store the extra pixels.\n    extern __shared__ PackedStorageType smem_storage_packed[];\n\n    for (int 
c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n        // The position in the NHW dimension where the CTA starts.\n        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n        // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n        // Compute the NHW coordinate of the thread in the CTA.\n        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n        // The position in the C dimension where the CTA starts.\n        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n        // Compute the C coordinate of the thread in the CTA.\n        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n        // Compute the C coordinate of the thread.\n        const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;\n\n        // Is the thread working on a valid C dimension?\n        const int is_valid_c = thread_c < params.c;\n\n        float mean[ELEMENTS_PER_LDG];\n        zero_array(mean);\n        if (is_valid_c) {\n            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);\n        }\n\n        // accumulation related registers\n        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // The number of elements loaded by this CTA.\n        int cta_count = 0;\n        // The base pointers to load from.\n        const uint16_t *gmem_src = &params.gmem_src[thread_c];\n        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];\n        uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];\n\n        // outer loops\n        int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;\n        // Load the batch of elements. 
Compute sum across them\n        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;\n\n        if (OUTER_LOOPS_ != 1) {\n            // We cannot load everything to store persistently, so let's make sure registers and\n            // smem are fully utilized; the offset is evenly divisible by 32\n            int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x -\n                          params.nhw) & ~31;\n            cta_nhw_regs -= offset;\n            cta_nhw_smem -= offset;\n        }\n\n        const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +\n                                      ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n        #pragma unroll 1\n        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n            // The nhw position.\n            int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;\n            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));\n\n            int lane_id = threadIdx.x & 31;\n\n            // Read the elements from memory.\n            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n            unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                zero_array(x_storage[i]);\n                zero_array(dy_storage[i]);\n                is_valid[i] = 0.f;\n                const bool is_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                if (is_valid_nhw) {\n                    if (is_valid_c) {\n                        if (loop_i == OUTER_LOOPS - 1) {\n                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                            
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);\n                        } else {\n                            ldg(x_storage[i], &gmem_src[idx*params.c]);\n                            ldg(dy_storage[i], &gmem_dy[idx*params.c]);\n                        }\n                        is_valid[i] = 1.f;\n                    }\n\n                    if (lane_id < ELEMENTS_PER_LDG) {\n                        relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];\n                    }\n                }\n            }\n\n            // Do the math.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                // Convert to float and update\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                bool rectified[ELEMENTS_PER_LDG];\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) &\n                                    (1U << lane_id)) != 0);\n                }\n                to_float(x_math, x_storage[i]);\n                to_float(dy_math, dy_storage[i]);\n\n                // Update the count.\n                count += is_valid[i];\n                // Invert the count.\n                float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n                relu_bwd(dy_math, rectified, is_valid[i]);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n                // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version\n                from_float(dy_storage[i], dy_math);\n\n                // dZ for elementwise add\n                if (is_valid[i]) {\n                    if (loop_i == OUTER_LOOPS - 1) {\n                        stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);\n                    } else {\n                        stg(&gmem_dst1[idx*params.c], dy_storage[i]);\n                    }\n                }\n            }\n        }\n\n        // The elements to load and store in SMEM.\n        int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;\n        // Load elements from SMEM, update the CTA count.\n        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);\n        if (pixels_in_smem > 0) {\n            cta_count += pixels_in_smem;\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_pixel_valid_nhw =\n                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n                const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;\n                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                                  dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                unsigned int relu_mask;\n                int lane_id = threadIdx.x & 31;\n                zero_array(x_storage_local);\n                zero_array(dy_storage_local);\n                if (is_pixel_valid_nhw) {\n                    if (is_valid_c) {\n                        ldg_stream(x_storage_local, &gmem_src[idx*params.c]);\n                        ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);\n                    
}\n                    if (lane_id < ELEMENTS_PER_LDG) {\n                        relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];\n                    }\n                }\n                bool rectified[ELEMENTS_PER_LDG];\n                #pragma unroll\n                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n                    rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) &\n                                    (1U << lane_id)) != 0);\n                }\n\n                // The offset to store in SMEM.\n                int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Store in SMEM.\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n                offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                // Update the count.\n                count += is_pixel_valid;\n                // Invert the count.\n                float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage_local);\n                to_float(dy_math, dy_storage_local);\n\n                relu_bwd(dy_math, rectified, is_pixel_valid);\n                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n                from_float(dy_storage_local, dy_math);\n                // dZ for elementwise add\n                if (is_pixel_valid) {\n                    stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);\n                }\n                // only store the 'relu-dgrad'ed version!\n                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n            }\n        }\n\n        // We scale the mean by the number of elements. 
Scaling by the count here improves numerical stability.\n        #pragma unroll\n        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            dbias[i] *= count;\n            dscale[i] *= count;\n        }\n\n        // dscale parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dscale, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n            smem, dbias, thread_in_cta_nhw);\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // The workspace in global memory is distributed across the different CTAs.\n        int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;\n        // Write the data for the CTA to global memory.\n        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];\n        if (threadIdx.x < THREADS_PER_PIXEL) {\n            const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;\n            write_to_gmem(&gmem_sums[                           0], idx, dscale);\n            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);\n        }\n\n        // The counters to count how many CTAs have retired at this point.\n        // A given cta uses the same counter every other time through the outer loop.\n        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n        // Reset the accumulators for global summation\n        zero_array(dscale);\n        zero_array(dbias);\n\n        // Build the global accumulation\n        #pragma unroll 1\n        for (int idx = 
threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {\n            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n            read_from_gmem(tmp1, gmem_sums,                              idx);\n            read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);\n\n            #pragma unroll\n            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n                dscale[i] += tmp1[i];\n                dbias[i] += tmp2[i];\n            }\n        }\n\n        // dscale parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dscale, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dscale, smem, thread_in_cta_c);\n        __syncthreads();\n\n        // dbias parallel sum\n        if (params.sync_iters>0) {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);\n        } else {\n            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(\n                smem, dbias, thread_in_cta_nhw);\n        }\n\n        __syncthreads();\n        // The values in shared memory correspond to the CTA-wide sums.\n        read_from_smem(dbias, smem, thread_in_cta_c);\n\n        // Normalize the dscale.\n        float var[ELEMENTS_PER_LDG];\n        zero_array(var);\n        if (is_valid_c) {\n            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);\n        }\n        
multiply(dscale, var);\n\n        // store dscale/dbias\n        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n        if (is_valid_for_saving) {\n            if (params.sync_iters>0)\n            {\n                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n            } else {\n                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);\n                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);\n            }\n        }\n\n        // Further normalize the dscale to be used in dx calculation\n        float scale[ELEMENTS_PER_LDG];\n        zero_array(scale);\n        if (is_valid_c) {\n            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);\n        }\n        multiply(dscale, var);\n        // scale the inv-var as well, afterwards\n        multiply(var, scale);\n\n        // inverse count\n        float inv_count = params.svar_inv_count;\n\n        // The base pointer to write to.\n        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];\n\n        // Store the elements in registers.\n        #pragma unroll 1\n        for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {\n            // The value for nhw.\n            int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;\n\n            // Normalize the elements and write to memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                // Convert to float.\n                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                to_float(x_math, x_storage[i]);\n   
             to_float(dy_math, dy_storage[i]);\n\n                float dx[ELEMENTS_PER_LDG];\n                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                // Write back.\n                if (is_valid) {\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n\n            // The next value of nhw.\n            out_nhw -= pixels_per_iteration;\n\n            // Read the next elements from memory.\n            #pragma unroll\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n                const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                float y[ELEMENTS_PER_LDG];\n                zero_array(y);\n                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);\n                    ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);\n                }\n            }\n        }\n\n        // Normalize the elements from SMEM and write them out.\n        if (pixels_in_smem > 0) {\n            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n                const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;\n                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n                if (is_valid) {\n                    // Read from SMEM.\n                    int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],\n                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;\n                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n                    float 
x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n                    to_float(x_math, x_storage_local);\n                    to_float(dy_math, dy_storage_local);\n\n                    float dx[ELEMENTS_PER_LDG];\n                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n                    // Write back.\n                    stg_stream(&gmem_dst[idx*params.c], dx);\n                }\n            }\n        }\n        // We're about to start on the next c-blk; sync so the shared memory buffers can be reused.\n        __syncthreads();\n    }\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/layer_norm/ln_api.cpp",
    "content": "#include <torch/extension.h>\n#include \"ATen/cuda/CUDAContext.h\"\n\nvoid ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,\n                 const at::Tensor &x, const at::Tensor &gamma,\n                 const at::Tensor &beta, const float epsilon, const int rows, const int cols,\n                 cudaStream_t stream);\n\nvoid ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);\n\n\nstd::vector<at::Tensor> ln_fwd(const at::Tensor &x,      // BxSxhidden_size\n                               const at::Tensor &gamma,   // hidden_size\n                               const at::Tensor &beta,   // hidden_size\n                               const float epsilon\n) {\n\n    TORCH_CHECK(x.is_cuda())\n    TORCH_CHECK(gamma.is_cuda())\n    TORCH_CHECK(beta.is_cuda())\n\n    TORCH_CHECK(x.is_contiguous());\n    auto sizes = x.sizes();\n    TORCH_CHECK(sizes.size() == 2);\n\n    const int rows = sizes[0];\n    const int cols = sizes[1];\n\n    auto dtype = x.scalar_type();\n\n    TORCH_CHECK(gamma.dtype() == dtype);\n    TORCH_CHECK(beta.dtype() == dtype);\n\n    TORCH_CHECK(gamma.sizes() == beta.sizes());\n    TORCH_CHECK(gamma.numel() == cols);\n\n    TORCH_CHECK(epsilon >= 0.f);\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    auto y = torch::empty_like(x);\n\n    auto opts = x.options();\n\n    auto mu = torch::empty({rows}, opts.dtype(torch::kFloat32));\n    auto rsigma = torch::empty({rows}, opts.dtype(torch::kFloat32));\n\n    ln_fwd_cuda(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, stream);\n\n    return {y, mu, rsigma};\n}\n\n\n\nstd::vector<at::Tensor> ln_bwd(const at::Tensor &dw,     // BxSxhidden_size\n                               const at::Tensor &x,      // 
BxSxhidden_size\n                               const at::Tensor &mu,     // BxS, FP32!\n                               const at::Tensor &rsigma, // BxS, FP32!\n                               const at::Tensor &gamma   // hidden_size\n) {\n\n  TORCH_CHECK(x.is_cuda());\n  TORCH_CHECK(dw.is_cuda());\n  TORCH_CHECK(mu.is_cuda());\n  TORCH_CHECK(rsigma.is_cuda());\n  TORCH_CHECK(gamma.is_cuda());\n\n  TORCH_CHECK(x.is_contiguous());\n  TORCH_CHECK(dw.is_contiguous());\n\n  auto sizes = x.sizes();\n  TORCH_CHECK(sizes.size() == 2);\n  TORCH_CHECK(dw.sizes() == sizes);\n  auto rows = sizes[0];\n  auto cols = sizes[1];\n  \n  auto dtype = x.scalar_type();\n  TORCH_CHECK(dw.dtype() == dtype);\n  TORCH_CHECK(gamma.dtype() == dtype);\n  TORCH_CHECK(mu.dtype() == torch::kFloat32);\n  TORCH_CHECK(rsigma.dtype() == torch::kFloat32);\n  TORCH_CHECK(mu.sizes() == rsigma.sizes());\n  TORCH_CHECK(mu.numel() == rows);\n\n  TORCH_CHECK(gamma.numel() == cols);\n\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  auto dx = torch::empty_like(x);\n  auto dgamma = torch::empty_like(gamma);\n  auto dbeta = torch::empty_like(gamma);\n  \n  ln_bwd_cuda(dx, dgamma, dbeta, dw, x, mu, rsigma, gamma, rows, cols, stream);\n\n  return {dx, dgamma, dbeta};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"CUDA LayerNorm\"; // optional module docstring\n  m.def(\"ln_fwd\", &ln_fwd, \"Run LayerNorm forward kernel\");\n  m.def(\"ln_bwd\", &ln_bwd, \"Run LayerNorm backward kernel\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
    "content": "#include \"utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n\ntemplate<typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(void * __restrict__ dx_,\n                                                                          void * __restrict__ dg_,\n                                                                          void * __restrict__ db_,\n                                                                          const void * __restrict__ dw_,\n                                                                          const void * __restrict__ x_,\n                                                                          const void * __restrict__ mu_,\n                                                                          const void * __restrict__ rs_,\n                                                                          const void * __restrict__ g_,\n                                                                          const int rows\n                                                                        ){\n  using Vec = typename Ktraits::Vec;\n\n  enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { COLS = Ktraits::COLS };\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };\n  static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, \"\");\n  enum { NUM_ELTS = Vec::NUM_ELTS };\n  using vec_t = typename Ktraits::vec_t;\n  using base_t = typename Ktraits::base_t;\n  using compute_t = typename Ktraits::compute_t;\n  const int tidx = threadIdx.x;\n  const int bidx = blockIdx.x;\n  const int lane = tidx % THREADS_PER_WARP;\n  const int warp = tidx / THREADS_PER_WARP;\n  const int warp_m = 
warp / Ktraits::WARPS_N;\n  const int warp_n = warp % Ktraits::WARPS_N;\n  const int tid_r = warp_n * THREADS_PER_WARP + lane;\n\n  const int r = bidx * Ktraits::ROWS_PER_CTA + warp_m;\n  const int c = warp_n * THREADS_PER_WARP + lane;\n\n  const char *dw_ptr = static_cast<const char *>(dw_);\n  const char *x_ptr = static_cast<const char *>(x_);\n  const char *g_ptr = static_cast<const char *>(g_);\n  char *dx_ptr = static_cast<char *>(dx_);\n  const compute_t *mu_ptr = static_cast<const compute_t *>(mu_);\n  const compute_t *rs_ptr = static_cast<const compute_t *>(rs_);\n  static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS, \"\");\n\n  // smem for final reduction\n  //__shared__ compute_t smem_[ROWS_PER_CTA * COLS];\n  extern __shared__ compute_t smem_[];\n  // static_assert(sizeof(smem_dw_sum) == 32*1024,\"\");\n  // Using the grid stride loop we can assign multiple rows to each thread\n  // by using a number of CTAs smaller than rows / ROWS_PER_CTA.\n  // We accumulate partial sums here, one copy in smem and one in registers,\n  // because the smem capacity is limited:\n  // compute_t * dw_sum = &smem_dw_sum[warp_m * COLS + tid_r * LDGS * NUM_ELTS];\n  compute_t dwy_sum[LDGS * NUM_ELTS];\n  compute_t dw_sum[LDGS * NUM_ELTS];\n\n  memset(dwy_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);\n  memset(dw_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);\n  // Debug 8 rows, 4B, 1024 cols\n\n  __shared__ compute_t smem_mdy[ROWS_PER_CTA * WARPS_N];\n  __shared__ compute_t smem_mdyy[ROWS_PER_CTA * WARPS_N];\n  compute_t *mdy_shared = &smem_mdy[warp_m * WARPS_N];\n  compute_t *mdyy_shared = &smem_mdyy[warp_m * WARPS_N];\n\n  constexpr float rn = 1.f / float(COLS);\n  Vec gamma[LDGS];\n  int col = c;\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n    gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);\n    col += Ktraits::THREADS_PER_ROW;\n  }\n  // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the\n  // last blocks with syncthreads!\n  // grid stride over 
rows\n  #pragma unroll 1\n  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {\n    const compute_t mu_r = mu_ptr[row];\n    const compute_t rs_r = rs_ptr[row];\n    Vec dw[LDGS], x[LDGS], dx[LDGS];\n    int col = c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dw[it].load_from(dw_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += THREADS_PER_ROW;\n    }\n    // local reductions\n    compute_t dy[LDGS * NUM_ELTS];\n    compute_t y[LDGS * NUM_ELTS];\n\n    compute_t mdy_local = 0.f;\n    compute_t mdyy_local = 0.f;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < Vec::NUM_ELTS; jt++) {\n        compute_t x_tmp = x[it].data.elt[jt];\n        compute_t y_tmp = rs_r * (x_tmp - mu_r);\n        compute_t dy_tmp = gamma[it].data.elt[jt] * dw[it].data.elt[jt];\n        compute_t dw_tmp = dw[it].data.elt[jt];\n\n        mdy_local += dy_tmp;\n        mdyy_local += dy_tmp * y_tmp;\n\n        dy[it * NUM_ELTS + jt] = dy_tmp;\n        y[it * NUM_ELTS + jt] = y_tmp;\n\n        dwy_sum[it * NUM_ELTS + jt] += dw_tmp * y_tmp;\n        dw_sum[it * NUM_ELTS + jt] += dw_tmp;\n      }\n    }\n\n    // reduction across row for mdy, mdyy\n    if (WARPS_N == 1) { // no need to go through smem!\n#pragma unroll\n      for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n        mdy_local += __shfl_xor_sync(uint32_t(-1), mdy_local, it);\n        mdyy_local += __shfl_xor_sync(uint32_t(-1), mdyy_local, it);\n      }\n\n      mdy_local *= rn;\n      mdyy_local *= rn;\n\n    } else {\n\n#pragma unroll\n      for (int it = 16; it > 0; it /= 2) {\n        mdy_local += __shfl_down_sync(uint32_t(-1), mdy_local, it);\n        mdyy_local += __shfl_down_sync(uint32_t(-1), mdyy_local, it);\n      } // lane 0 holds the result!\n\n      if (lane == 0) {\n        mdy_shared[warp_n] = mdy_local;\n        mdyy_shared[warp_n] = 
mdyy_local;\n      }\n\n      __syncthreads();\n      if (warp_n == 0 && lane == 0) {\n        mdy_local = 0.f;\n        mdyy_local = 0.f;\n        for (int it = 0; it < WARPS_N; it++) {\n          mdy_local += mdy_shared[it];\n          mdyy_local += mdyy_shared[it];\n        }\n        mdy_shared[0] = mdy_local;\n        mdyy_shared[0] = mdyy_local;\n      }\n      __syncthreads();\n\n      mdy_local = mdy_shared[0] * rn;\n      mdyy_local = mdyy_shared[0] * rn;\n    }\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t dy_tmp = dy[it * NUM_ELTS + jt];\n        compute_t y_tmp = y[it * NUM_ELTS + jt];\n        compute_t dx_tmp =\n            compute_t(rs_r) * (dy_tmp - mdyy_local * y_tmp - mdy_local);\n        dx[it].data.elt[jt] = dx_tmp;\n      }\n    }\n\n    col = c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dx[it].store_to(dx_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += Ktraits::THREADS_PER_ROW;\n    }\n\n  } // end: grid stride loop\n\n  // Finalize reduction of part dgamma and dbeta for this CTA\n  // by reducing over the rows held across the WARPS_M warps\n\n  enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };\n  static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, \"\");\n\n  compute_t *smem_write;\n\n  smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n    for (int jt = 0; jt < NUM_ELTS; jt++) {\n      smem_write[jt] = dw_sum[it * NUM_ELTS + jt];\n    }\n    smem_write += THREADS_PER_ROW * NUM_ELTS;\n  }\n  __syncthreads();\n  compute_t cta_dw_sum[NUM_RES];\n  memset(cta_dw_sum, 0, sizeof(compute_t) * NUM_RES);\n  for (int it = 0; it < ROWS_PER_CTA; it++) {\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      cta_dw_sum[jt] += smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n    }\n  }\n  __syncthreads();\n\n  smem_write = &smem_[warp_m * 
COLS + tid_r * NUM_ELTS];\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n    for (int jt = 0; jt < NUM_ELTS; jt++) {\n      smem_write[jt] = dwy_sum[it * NUM_ELTS + jt];\n    }\n    smem_write += THREADS_PER_ROW * NUM_ELTS;\n  }\n  __syncthreads();\n  compute_t cta_dwy_sum[NUM_RES];\n  memset(cta_dwy_sum, 0, sizeof(compute_t) * NUM_RES);\n  for (int it = 0; it < ROWS_PER_CTA; it++) {\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      cta_dwy_sum[jt] +=\n          smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n    }\n  }\n\n  compute_t *dgamma_part = static_cast<compute_t *>(dg_) + bidx * COLS + tidx;\n  for (int jt = 0; jt < NUM_RES; jt++) {\n    *dgamma_part = cta_dwy_sum[jt];\n    dgamma_part += Ktraits::THREADS_PER_CTA;\n  }\n\n  compute_t *dbeta_part = static_cast<compute_t *>(db_) + bidx * COLS + tidx;\n  for (int jt = 0; jt < NUM_RES; jt++) {\n    *dbeta_part = cta_dw_sum[jt];\n    dbeta_part += Ktraits::THREADS_PER_CTA;\n  }\n}\n\ntemplate<typename Ktraits, typename out_t>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(void * __restrict__ dg_,\n                                                                                   void * __restrict__ db_,\n                                                                                   const void * __restrict__ dg_part_,\n                                                                                   const void * __restrict__ db_part_,\n                                                                                   const int rows\n                                                                                  ){\n    using Vec = typename Ktraits::Vec;\n    enum { NUM_ELTS = Vec::NUM_ELTS };\n\n\n    using vec_t = typename Ktraits::vec_t;\n    using base_t = typename Ktraits::base_t;\n    using compute_t = typename Ktraits::compute_t;\n\n    enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n    enum { ROWS_PER_CTA = 
Ktraits::ROWS_PER_CTA };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { COLS = Ktraits::COLS };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum {VEC_COLS = BYTES_PER_ROW / BYTES_PER_LDG};\n    //dbg\n    static_assert(VEC_COLS == COLS / NUM_ELTS, \"\"); \n    //static_assert(VEC_COLS == 1024,\"\");\n    const int tidx = threadIdx.x;\n    const int bidx = blockIdx.x;\n    const int lane = tidx % THREADS_PER_WARP;\n    const int warp = tidx / THREADS_PER_WARP;\n    const int warp_m = warp / Ktraits::WARPS_N;\n    const int warp_n = warp % Ktraits::WARPS_N;\n    const int tid_c = warp_n * THREADS_PER_WARP + lane;\n    const int c =bidx * THREADS_PER_ROW + tid_c;\n    const int r = warp_m;\n    \n    __shared__ compute_t smem_[(WARPS_M - 1) * THREADS_PER_ROW * NUM_ELTS];\n    \n    //Will probably run this with WARPS_N = 1 and grid = 1024 / (32*4) = 8, or NUM_ELTS=1 and grid = 32 \n    // and WARPS_M = 4 (or 1??)\n    for(int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW){\n      const char* dg_part_ptr = static_cast<const char*>(dg_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;\n      const char* db_part_ptr = static_cast<const char*>(db_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;\n\n      compute_t dg_sum[NUM_ELTS];\n      compute_t db_sum[NUM_ELTS];\n      memset(dg_sum, 0, sizeof(compute_t) * NUM_ELTS);\n      memset(db_sum, 0, sizeof(compute_t) * NUM_ELTS);\n      #pragma unroll\n      for(int row = r; row < rows;row += ROWS_PER_CTA){\n        Vec dg;\n        Vec db;\n        dg.load_from(dg_part_ptr);\n        db.load_from(db_part_ptr);\n        dg_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;\n        db_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;\n\n        #pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          dg_sum[jt] += dg.data.elt[jt];\n          db_sum[jt] += db.data.elt[jt];\n        }\n      
}\n\n      // Finalize the reduction across rows of the CTA\n      compute_t * smem_write;\n      smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;\n\n      if (warp_m > 0) {\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          *smem_write = dg_sum[jt];\n          smem_write+=THREADS_PER_ROW;\n        }\n      }\n      __syncthreads();\n      compute_t *smem_read ;\n      smem_read = smem_ + tid_c ;\n      if (warp_m == 0) {\n#pragma unroll\n        for (int it = 0; it < WARPS_M - 1; it++) {\n#pragma unroll\n          for (int jt = 0; jt < NUM_ELTS; jt++) {\n            dg_sum[jt] += *smem_read;\n            smem_read += THREADS_PER_ROW;\n          }\n        }\n      }\n\n      __syncthreads();\n\n      smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;\n\n      if (warp_m > 0) {\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          *smem_write = db_sum[jt];\n          smem_write+=THREADS_PER_ROW;\n        }\n      }\n      __syncthreads();\n      smem_read = smem_ + tid_c;\n      if (warp_m == 0) {\n#pragma unroll\n        for (int it = 0; it < WARPS_M - 1; it++) {\n#pragma unroll\n          for (int jt = 0; jt < NUM_ELTS; jt++) {\n            db_sum[jt] += *smem_read;\n            smem_read += THREADS_PER_ROW;\n          }\n        }\n\n        using vout_t = typename Vec_type<sizeof(out_t) * NUM_ELTS>::Type;\n        union {\n          vout_t raw;\n          out_t elt[NUM_ELTS];\n        } dg_out, db_out;\n\n        // out_t dg_out[NUM_ELTS], db_out[NUM_ELTS];\n#pragma unroll\n        for (int jt = 0; jt < NUM_ELTS; jt++) {\n          dg_out.elt[jt] = dg_sum[jt];\n          db_out.elt[jt] = db_sum[jt];\n        }\n        vout_t *dg_ptr = reinterpret_cast<vout_t *>(dg_) + col ;\n        vout_t *db_ptr = reinterpret_cast<vout_t *>(db_) + col ;\n        *dg_ptr = dg_out.raw;\n        *db_ptr = db_out.raw;\n      }\n    }\n}\n\ntemplate<typename scalar_t>\nvoid launch(at::Tensor 
&dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 at::Tensor &dgamma_part, at::Tensor &dbeta_part,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, const int gridx, cudaStream_t stream){\n\n  if (cols == 1024) {\n    using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;\n\n    if (Ktraits::SMEM_BYTES >= 48 * 1024) {\n      AT_CUDA_CHECK(cudaFuncSetAttribute(\n          ln_bwd_kernel<Ktraits>, cudaFuncAttributeMaxDynamicSharedMemorySize,\n          Ktraits::SMEM_BYTES));\n    }\n\n    ln_bwd_kernel<Ktraits>\n        <<<gridx, Ktraits::THREADS_PER_CTA, Ktraits::SMEM_BYTES, stream>>>(\n            dx.data_ptr(), dgamma_part.data_ptr(), dbeta_part.data_ptr(),\n            dw.data_ptr(), x.data_ptr(), mu.data_ptr(), rsigma.data_ptr(),\n            gamma.data_ptr(), rows);\n\n    using Ktraits2 = Kernel_traits<float, 1024, 16, 1, 4>;\n\n    constexpr int grid2 =\n        DIVUP(1024, Ktraits2::THREADS_PER_ROW * Ktraits2::Vec::NUM_ELTS);\n\n    ln_bwd_finalize_kernel<Ktraits2, scalar_t>\n        <<<grid2, Ktraits2::THREADS_PER_CTA, 0, stream>>>(\n            dgamma.data_ptr(), dbeta.data_ptr(), dgamma_part.data_ptr(),\n            dbeta_part.data_ptr(), gridx);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n  AT_CUDA_CHECK(cudaPeekAtLastError());\n}\n\nvoid ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,\n                 const at::Tensor &dw, const at::Tensor &x,\n                 const at::Tensor &mu, const at::Tensor &rsigma,\n                 const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream) {\n\n\n  const auto dtype = x.scalar_type();\n\n\n  const auto props = at::cuda::getCurrentDeviceProperties();\n  const int smCount = props->multiProcessorCount;\n  // Launch 2 CTAs per SM \n  const int grid = 2 * smCount;\n\n  //request workspace for 
two-step reduction. We always reduce in FP32.\n  auto opts = x.options();\n  auto dbeta_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));\n  auto dgamma_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));\n\n  if (dtype == torch::kFloat16) {\n    launch<half>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);\n  } else if (dtype == torch::kFloat32) {\n    launch<float>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n}"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
    "content": "#include \"utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(\n    void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_,\n    const void *__restrict__ x_, const void *__restrict__ gamma_,\n    const void *__restrict__ beta_, const float epsilon, int rows) {\n\n  using Vec = typename Ktraits::Vec;\n\n  using base_t = typename Ktraits::base_t;\n  using compute_t = typename Ktraits::compute_t;\n  enum { NUM_ELTS = Vec::NUM_ELTS };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };\n  static_assert(BYTES_PER_LDG == 16, \"\");\n\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };\n  static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, \"\");\n\n  const int tidx = threadIdx.x;\n  const int bidx = blockIdx.x;\n  const int lane = tidx % THREADS_PER_WARP;\n  const int warp = tidx / THREADS_PER_WARP;\n  const int warp_n = warp % WARPS_N;\n  const int warp_m = warp / WARPS_N;\n\n  const int c = warp_n * THREADS_PER_WARP + lane;\n  const int r = bidx * ROWS_PER_CTA + warp_m;\n\n  const char *x_ptr = static_cast<const char *>(x_);\n\n  const char *g_ptr = static_cast<const char *>(gamma_);\n  const char *b_ptr = static_cast<const char *>(beta_);\n\n  char *y_ptr = static_cast<char *>(y_);\n  compute_t *mu_ptr = static_cast<compute_t *>(mu_);\n  compute_t *rs_ptr = static_cast<compute_t *>(rsigma_);\n\n  Vec gamma[LDGS];\n  Vec beta[LDGS];\n#pragma unroll\n  for (int it = 0, col = c; it < LDGS; it++) {\n    gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);\n    beta[it].load_from(b_ptr + col * BYTES_PER_LDG);\n    col += 
THREADS_PER_ROW;\n  }\n\n  constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);\n  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {\n    Vec x[LDGS];\n#pragma unroll\n    for (int it = 0, col = c; it < LDGS; it++) {\n      x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += THREADS_PER_ROW;\n    }\n    compute_t xf[LDGS * NUM_ELTS];\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);\n      }\n    }\n\n    compute_t mu_local = 0.f;\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        mu_local += xf[it * NUM_ELTS + jt];\n      }\n    }\n\n#pragma unroll\n    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n      mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);\n    }\n    mu_local *= rn;\n    if(lane == 0){\n    mu_ptr[row] = mu_local;\n    }\n    compute_t var_local = 0.f;\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;\n        var_local += diff * diff;\n      }\n    }\n\n#pragma unroll\n    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n      var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);\n    }\n    compute_t rsigma = rsqrtf(var_local * rn + epsilon);\n    if(lane == 0){\n    rs_ptr[row] = rsigma;\n    }\n\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));\n        x[it].data.elt[jt] = gamma[it].data.elt[jt] *  tmp + beta[it].data.elt[jt];\n      }\n    }\n\n#pragma unroll\n    for (int it = 0, col = c; it < LDGS; it++) {\n      x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);\n      col += 
THREADS_PER_ROW;\n    }\n  }\n}\ntemplate<typename scalar_t>\nvoid launch(\n    at::Tensor & y, // BxSxhidden_size\n    at::Tensor & mu,\n    at::Tensor & rsigma,\n    const at::Tensor & x, // BxSxhidden_size\n    const at::Tensor & gamma,\n    const at::Tensor & beta,\n    const float epsilon,\n    const int rows,\n    const int cols,\n    const int max_gridx,\n    cudaStream_t stream\n){\n\n  if (cols == 1024) {\n    using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;\n    const int grid =\n        std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);\n\n    ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(\n        y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),\n        gamma.data_ptr(), beta.data_ptr(), epsilon, rows);\n\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n  AT_CUDA_CHECK(cudaPeekAtLastError());\n}\n\nvoid ln_fwd_cuda(\n    at::Tensor & y, // BxSxhidden_size\n    at::Tensor & mu,\n    at::Tensor & rsigma,\n    const at::Tensor & x, // BxSxhidden_size\n    const at::Tensor & gamma,\n    const at::Tensor & beta,\n    const float epsilon,\n    const int rows, const int cols,\n    cudaStream_t stream\n){\n\n  const auto dtype = x.scalar_type();\n  const auto props = at::cuda::getCurrentDeviceProperties();\n  const int max_gridx = props->maxGridSize[0];\n\n  // TODO:\n  // - Using the dispatch macro costs ~1% perf (reason unclear)\n  // - Tune FP32 warps\n  // - Add more sizes\n  if (dtype == torch::kFloat16) {\n    launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);\n  } else if (dtype == torch::kFloat32) {\n    launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);\n  } else {\n    assert(false && \"Not implemented\");\n  }\n\n}"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/layer_norm/ln_kernel_traits.h",
    "content": "#pragma once\n\nconstexpr uint32_t THREADS_PER_WARP = 32;\n\ntemplate <typename dtype, int COLS_, int WARPS_M_, int WARPS_N_,\n          int BYTES_PER_LDG_ = 16>\nstruct Kernel_traits {\n  enum { WARPS_M = WARPS_M_ };\n  enum { WARPS_N = WARPS_N_ };\n  enum { COLS = COLS_ };\n  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n\n  using Vec = Vec<dtype, BYTES_PER_LDG>;\n\n  using vec_t = typename Vec::vec_t;\n  using base_t = typename Vec::base_t;\n  using packed_t = typename Vec::packed_t;\n  using compute_t = typename Vec::compute_t;\n  using packed_compute_t = typename Vec::packed_compute_t;\n\n  enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };\n  enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };\n  enum { ROWS_PER_CTA = WARPS_M };\n\n  enum { BYTES_PER_ROW = COLS * sizeof(base_t) };\n  enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };\n  enum {SMEM_BYTES = ROWS_PER_CTA * COLS * sizeof(compute_t)};\n};\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/layer_norm/utils.cuh",
    "content": "#pragma once\n\n#include \"torch/extension.h\"\n#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK\n\n#define DIVUP(x, y) (((x) + ((y)-1)) / (y))\n\n#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                               \\\n  [&] {                                                                        \\\n    const auto &the_type = TYPE;                                               \\\n    /* don't use TYPE again in case it is an expensive or side-effect op */    \\\n    at::ScalarType _st = ::detail::scalar_type(the_type);                      \\\n    switch (_st) {                                                             \\\n      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)          \\\n      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)        \\\n    default:                                                                   \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(_st), \"'\");           \\\n    }                                                                          \\\n  }()\n\ntemplate <int Bytes> struct Vec_type {};\n\ntemplate <> struct Vec_type<16> {\n  using Type = uint4;\n  static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }\n};\ntemplate <> struct Vec_type<8> {\n  using Type = uint2;\n  static __device__ inline Type zero() { return make_uint2(0, 0); }\n};\n\ntemplate <> struct Vec_type<4> {\n  using Type = uint32_t;\n  static __device__ inline Type zero() { return 0; }\n};\n\ntemplate <> struct Vec_type<2> {\n  using Type = uint16_t;\n  static __device__ inline Type zero() { return 0; }\n};\n\ntemplate <typename T> struct TypeInfo {\n  using base_t = T;\n  using packed_t = T;\n  using compute_t = float;\n  using packed_compute_t = float;\n};\n\ntemplate <> struct TypeInfo<half> {\n  using base_t = half;\n  using packed_t = half2;\n  using compute_t = float;\n  using packed_compute_t = float2;\n};\n\ntemplate <typename dtype, int Bytes> 
struct Vec {\n\n  using base_t = typename TypeInfo<dtype>::base_t;\n  using packed_t = typename TypeInfo<dtype>::packed_t;\n  using compute_t = typename TypeInfo<dtype>::compute_t;\n  using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;\n\n  static_assert(Bytes % sizeof(base_t) == 0, \"\");\n  static_assert(Bytes % sizeof(packed_t) == 0, \"\");\n  enum { BYTES_PER_THREAD = Bytes };\n  enum { NUM_ELTS = Bytes / sizeof(base_t) };\n  enum { NUM_PACKED = Bytes / sizeof(packed_t) };\n  using vec_t = typename Vec_type<Bytes>::Type;\n  using store_t = union {\n    vec_t raw;\n    base_t elt[NUM_ELTS];\n    packed_t packed[NUM_PACKED];\n  };\n  store_t data;\n\n  __device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }\n\n  __device__ inline void load_from(const char *ptr) {\n    data.raw = *reinterpret_cast<const vec_t *>(ptr);\n  }\n\n  __device__ inline void load_or_zero(const char *ptr, const bool is_valid) {\n    data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)\n                        : Vec_type<Bytes>::zero();\n  }\n\n  __device__ inline void store_to(char *ptr) const {\n    *reinterpret_cast<vec_t *>(ptr) = data.raw;\n  }\n\n  __device__ inline void store_valid(char *ptr, const bool is_valid) const {\n    if (is_valid)\n      *reinterpret_cast<vec_t *>(ptr) = data.raw;\n  }\n};\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp",
    "content": "#include <torch/extension.h>\n#include <cuda_fp16.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const half*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t       bool \t\t\t\tuse_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(input.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  
\tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  }\n\n  return fwd_cuda(\n                                 is_training,\n                                 heads, \n                                 input, \n                                 use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\ntorch::Tensor bwd(\n\t\t               bool use_mask,\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n//  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n\t\t                 heads,\n                                 output_grads,\n                                 softmax_results, \n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace additive_mask_softmax_dropout\n} // end namespace fused_softmax\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, \"Self Multihead Attention masked softmax dropout -- Forward.\");\n  
m.def(\"backward\", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, \"Self Multihead Attention masked softmax dropout -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"softmax.h\"\n#include \"dropout.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n\t\t\t       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const half*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = input.size(0);\n  const int   sequences      = attn_batches / heads;\n  const int   q_seq_len      = input.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n\n  // Softmax 
Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n      softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {\n           dropout_results,  \n           dropout_mask, \n           softmax_results\n         };\n}\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results, \n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = output_grads.size(0);\n  const int   q_seq_len      
= output_grads.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  // TODO: multiple CUDA streams could be used in the backward pass;\n  // currently only a single stream is used\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n//  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n  // The backward pass is performed completely in-place\n  return output_grads;\n}\n}\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/dropout.h",
    "content": "#include <ATen/ATen.h>\n\n#ifdef OLD_GENERATOR\n#include <ATen/CUDAGenerator.h>\n#else\n#include <ATen/CUDAGeneratorImpl.h>\n#endif\n\n#include <ATen/cuda/CUDAContext.h>\n#include <curand_kernel.h>\n\n#include <THC/THCGeneral.h>\n\nconst int UNROLL = 4;\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_fused_dropout_kernel(scalar_t const                *inputs,\n                                          scalar_t                      *outputs,\n                                          uint8_t                       *mask,\n                                          IndexType                      totalElements, \n\t\t                                  accscalar_t                    p, \n\t\t                                  std::pair<uint64_t, uint64_t>  seeds\n                                         ) \n{\n  accscalar_t pinv = accscalar_t(1)/p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(\n      seeds.first,\n      idx,\n      seeds.second,\n      &state);\n\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       float4 rand = curand_uniform4(&state);\n       scalar_t src[UNROLL];\n       rand.x = rand.x <= p;\n       rand.y = rand.y <= p;\n       rand.z = rand.z <= p;\n       rand.w = rand.w <= p;\n\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii] = inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           outputs[li] = 
src[ii]*(&rand.x)[ii]*pinv;\n               mask[li]    = (uint8_t)(&rand.x)[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_dropout_add_kernel(scalar_t const                *inputs,\n                                        scalar_t const                *add_inputs,\n                                        scalar_t                      *outputs,\n                                        uint8_t                       *mask,\n                                        IndexType                      totalElements, \n\t\t                                accscalar_t                    p, \n\t\t                                std::pair<uint64_t, uint64_t>  seeds\n                                       ) \n{\n  accscalar_t pinv = accscalar_t(1)/p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(\n      seeds.first,\n      idx,\n      seeds.second,\n      &state);\n\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       float4 rand = curand_uniform4(&state);\n       scalar_t src[UNROLL];\n       scalar_t add_src[UNROLL];\n       rand.x = rand.x <= p;\n       rand.y = rand.y <= p;\n       rand.z = rand.z <= p;\n       rand.w = rand.w <= p;\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii]     = inputs[li];\n               add_src[ii] = add_inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           
accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;\n\t           outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);\n               mask[li]    = (uint8_t)(&rand.x)[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\n__global__ void apex_add_kernel(          scalar_t const                *inputs,\n                                        scalar_t const                *add_inputs,\n                                        scalar_t                      *outputs,\n                                        IndexType                      totalElements\n                             ) \n{\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) {\n       scalar_t src[UNROLL];\n       scalar_t add_src[UNROLL];\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii]     = inputs[li];\n               add_src[ii] = add_inputs[li];\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n\t           outputs[li] = src[ii] + add_src[ii];\n           }\n       }\n       __syncthreads();\n  }\n}\n\ntemplate<typename scalar_t, \n\t\t typename accscalar_t, \n\t\t typename IndexType\n\t\t>\n__global__ void apex_masked_scale_kernel(scalar_t const *inputs, \n                                         scalar_t       *outputs, \n                                         uint8_t const  *mask, \n                                         IndexType       
totalElements,\n                                         accscalar_t     scale\n                                        )\n{\n  IndexType idx          = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx;\n       linearIndex < rounded_size;\n       linearIndex += gridDim.x * blockDim.x*UNROLL) \n  {\n       scalar_t src[UNROLL];\n       scalar_t msk[UNROLL];\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               src[ii] = static_cast<scalar_t>(inputs[li]);\n               msk[ii] = static_cast<scalar_t>(mask[li]);\n           }\n       }\n       for (int ii = 0; ii < UNROLL; ii++) {\n           IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n           if (li < totalElements) {\n               outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);\n           }\n       }\n  }\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_fused_dropout_cuda(scalar_t const *inputs,\n                           scalar_t       *outputs,\n                           uint8_t        *mask,\n                           IndexType       totalElements, \n\t\t                   accscalar_t     p)\n{\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n  \n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  //number of times random will be generated per thread, to offset philox counter in thc random state\n  
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n#ifdef OLD_GENERATOR\n    std::lock_guard<std::mutex> lock(gen->mutex_);\n    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);\n#else\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n#endif\n  }\n\n  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_dropout_add_cuda(scalar_t const *inputs,\n                           scalar_t const *add_inputs,\n                           scalar_t       *outputs,\n                           uint8_t        *mask,\n                           IndexType       totalElements, \n\t\t                   accscalar_t     p)\n{\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n  \n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  //number of times random will be generated per thread, to offset philox counter in thc random state\n  int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n#ifdef OLD_GENERATOR\n    std::lock_guard<std::mutex> lock(gen->mutex_);\n    
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);\n#else\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n#endif\n  }\n\n  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate <\n          typename scalar_t,\n          typename accscalar_t,\n          typename IndexType\n         >\nvoid apex_add_cuda(scalar_t const *inputs,\n                   scalar_t const *add_inputs,\n                   scalar_t       *outputs,\n                   IndexType       totalElements\n\t\t          )\n{\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);\n  THCudaCheck(cudaGetLastError());\n}\n\ntemplate<typename scalar_t, \n         typename accscalar_t, \n         typename IndexType\n        >\nvoid apex_masked_scale_cuda(scalar_t const *inputs, \n                          scalar_t       *outputs, \n                          uint8_t const  *mask, \n                          IndexType       totalElements,\n                          accscalar_t     scale\n                         )\n{\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size -1)/block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;\n  grid.x = 
std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);\n  THCudaCheck(cudaGetLastError());\n}\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) 
AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs_q.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()        == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights_q.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()   == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs_q.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  \n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 
2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs_q, \n                                 inputs_kv, \n                                 input_weights_q, \n                                 input_weights_kv, \n                                 output_weights, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()      == 3, \"expected 3D tensor\");\n  
AT_ASSERTM(softmax_results.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_q_results.dim()  == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_kv_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_q.dim()             == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights_q.dim()      == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()         == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_q_results.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_q.type().scalarType()             == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()         == at::ScalarType::Byte, \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                  
               heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_q_results, \n                                 input_lin_kv_results, \n                                 inputs_q, \n                                 inputs_kv, \n                                 input_weights_q,\n                                 input_weights_kv,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::encdec::cublas_gemmex::fwd, \"Encdec Multihead Attention Forward.\");\n  m.def(\"backward\", &multihead_attn::encdec::cublas_gemmex::bwd, \"Encdec Multihead Attention Backward.\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const 
int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs_q.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len},      mask_options);\n  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},       act_options);\n  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Q Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_q_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_q_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_kv_dim, \n                             batches_kv, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             
k_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_kv_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim_q, \n                             batch_stride_q, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n          
                   attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(softmax_results.data_ptr()), \n                               static_cast<at::Half*>(dropout_results.data_ptr()), \n                               static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                               dropout_elems,\n                               (1.0f - dropout_prob));\n  }\n \n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_lin_q_results, \n           input_lin_kv_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n         
                      torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_q_grads          = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads         = torch::empty_like(inputs_kv);\n  torch::Tensor input_weight_q_grads   = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads  = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n  \n  auto q_lin_grads_ptr   = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n              
               static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             
static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim_kv, \n                             batch_stride_kv,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  
softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                            
 k_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Input Linear Q Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_q, \n                             output_lin_q_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_q_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Q Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_q_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_q_dim,\n                             static_cast<const void*>(&beta),\n                           
  static_cast<void*>(input_weight_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_kv, \n                             output_lin_kv_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_kv_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_kv_dim,\n                             batches_kv, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F,\n                             
output_lin_kv_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_q_grads, \n           input_kv_grads, \n           input_weight_q_grads, \n           input_weight_kv_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n  
                             torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs_q.dim()               == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()              == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim()  == 1, \"expected 1D tensor\");\n  
AT_ASSERTM(lyr_nrm_beta_weights.dim()   == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights_q.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()         == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs_q.type().scalarType()              == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n  \n  if (use_mask) {\n    AT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n    AT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs_q, \n                                 inputs_kv,\n                                 lyr_nrm_gamma_weights,\n                                 lyr_nrm_beta_weights,\n                                 input_weights_q, \n                                 input_weights_kv, \n                                 output_weights, \n                                 use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_q_results.dim()   == 3, \"expected 3D tensor\");\n  
AT_ASSERTM(input_lin_kv_results.dim()  == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, \"expected 1D tensor\");\n  AT_ASSERTM(inputs_q.dim()              == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs_kv.dim()             == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights_q.dim()       == 2, \"expected 2D tensor\");\n  AT_ASSERTM(input_weights_kv.dim()      == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_add_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_q_results.type().scalarType()   == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_kv_results.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(inputs_q.type().scalarType()       
       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_q_results, \n                                 input_lin_kv_results, \n                                 lyr_nrm_results,\n                                 lyr_nrm_mean,\n                                 lyr_nrm_invvar,\n                                 inputs_q, \n                                 inputs_kv, \n                                 lyr_nrm_gamma_weights,\n                                 lyr_nrm_beta_weights,\n                                 input_weights_q,\n                                 input_weights_kv,\n                                 output_weights,\n                                 dropout_mask,\n                                 dropout_add_mask,\n                                 dropout_prob\n                                );\n}\n\n} // 
end namespace cublas_gemmex\n} // end namespace encdec_norm_add \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, \"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.\");\n  m.def(\"backward\", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, \"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
"content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   total_tokens_q    = batches_q * embed_dim;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = 
heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 *head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options                   = inputs_q.options().requires_grad(false);\n  auto lyr_nrm_options               = act_options.dtype(torch::kFloat32);\n  auto mask_options                  = act_options.dtype(torch::kUInt8);\n  \n  torch::Tensor lyr_nrm_mean         = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar       = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results      = torch::empty_like(inputs_q, act_options);\n\n  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);\n  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len},      mask_options);\n  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},       act_options);\n  
torch::Tensor output_lin_results   = torch::empty_like(inputs_q, act_options);\n  torch::Tensor dropout_add_mask     = torch::empty_like(inputs_q, mask_options);\n  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half,float>(\n                             static_cast<at::Half*>(lyr_nrm_results.data_ptr()),\n                             static_cast<float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<float*>(lyr_nrm_invvar.data_ptr()),\n                             static_cast<const at::Half*>(inputs_q.data_ptr()),\n                             static_cast<int>(batches_q), // n1\n                             static_cast<int>(embed_dim), // n2\n                             1.0e-5,\n                             static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n                             static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Q Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_q_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             //static_cast<const void*>(inputs_q.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_q_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_kv_dim, \n                             batches_kv, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             k_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_kv_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n        
                     scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim_q, \n                             batch_stride_q, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             
attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n  \n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(softmax_results.data_ptr()), \n                             static_cast<at::Half*>(dropout_results.data_ptr()), \n                             static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()), \n                             //static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // End-of-block Dropout-Add \n  if (is_training) {\n    apex_dropout_add_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                             static_cast<at::Half const*>(inputs_q.data_ptr()), \n                             static_cast<at::Half*>(outputs.data_ptr()), \n                         
    static_cast<uint8_t*>(dropout_add_mask.data_ptr()),\n                             total_tokens_q,\n                             (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                             static_cast<at::Half const*>(inputs_q.data_ptr()), \n                             static_cast<at::Half*>(outputs.data_ptr()), \n                             total_tokens_q);\n  }\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           lyr_nrm_results,\n           lyr_nrm_mean,\n           lyr_nrm_invvar, \n           input_lin_q_results, \n           input_lin_kv_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           dropout_add_mask, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results,\n                               torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs_q, \n                               torch::Tensor const& inputs_kv, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q,\n                            
   torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim         = inputs_q.size(2);\n  const int   sequences         = inputs_q.size(1);\n  const int   q_seq_len         = inputs_q.size(0);\n  const int   k_seq_len         = inputs_kv.size(0);\n  const int   batches_q         = sequences * q_seq_len;\n  const int   batches_kv        = sequences * k_seq_len;\n  const int   total_tokens_q    = batches_q * embed_dim;\n  const int   head_dim          = embed_dim / heads;\n  const int   output_lin_q_dim  = embed_dim;\n  const int   output_lin_kv_dim = 2 * embed_dim;\n  const int   attn_batches      = heads * sequences;\n  const int   lead_dim_q        = attn_batches * head_dim;\n  const int   lead_dim_kv       = attn_batches * 2 * head_dim;\n  const int   batch_stride_q    = head_dim;\n  const int   batch_stride_kv   = 2 * head_dim;\n  const int   dropout_elems     = attn_batches * q_seq_len * k_seq_len;\n  const float alpha             = 1.0;\n  const float beta              = 0.0;\n  const float scale             = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Backprop could overlap work with multiple CUDA streams; this first\n  // implementation keeps everything on a single stream.\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_q_grads          = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads         = torch::empty_like(inputs_kv);\n  torch::Tensor lyr_nrm_gamma_grads    = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads     = 
torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_q_grads   = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads  = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor dropout_add_grads         = torch::empty_like(output_grads);\n  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n  at::Tensor input_lin_q_grads         = torch::empty_like(inputs_q);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n  \n  auto q_lin_grads_ptr   = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr   = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  \n  // Dropout Add Backward  \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_grads.data_ptr()),\n                             static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),\n                             total_tokens_q,\n                             (1.0 / (1.0 - dropout_prob)));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, 
\n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches_q, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                  
           a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim_kv, \n                             batch_stride_kv,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             
static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim_q, \n                             batch_stride_q, \n                             
static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim_kv, \n                             batch_stride_kv, \n                             attn_batches);\n\n  // Input Linear Q Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_q, \n                             output_lin_q_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_q.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_q_dim, \n                             static_cast<const void*>(&beta),\n                             //static_cast<void*>(input_q_grads.data_ptr()),\n                             static_cast<void*>(input_lin_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Q Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_q_dim,\n                             batches_q, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs_q.data_ptr()),\n                             CUDA_R_16F,\n      
                       embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_q_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_q_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches_kv, \n                             output_lin_kv_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights_kv.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_kv_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear KV Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_kv_dim,\n                             batches_kv, \n                             static_cast<const void*>(&alpha),\n                          
   static_cast<const void*>(inputs_kv.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(k_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_kv_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_kv_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half,float>(\n                             static_cast<const half*>(input_lin_q_grads.data_ptr()),\n                             static_cast<half const*>(output_grads.data_ptr()), \n                             static_cast<const float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<const float*>(lyr_nrm_invvar.data_ptr()),\n                             inputs_q,\n                             static_cast<int>(batches_q),  // n1\n                             static_cast<int>(embed_dim),  // n2\n                             static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),\n                             static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),\n                             1.0e-5,\n                             static_cast<half*>(input_q_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_beta_grads.data_ptr())\n                                  );\n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_q_grads, \n           input_kv_grads, \n           lyr_nrm_gamma_grads, \n           lyr_nrm_beta_grads, \n           input_weight_q_grads, \n           input_weight_kv_grads, 
\n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace encdec_norm_add \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/layer_norm.h",
    "content": "#include \"ATen/ATen.h\"\n#include <THC/THCDeviceUtils.cuh>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\ntemplate<typename U> __device__\nvoid cuWelfordOnlineSum(\n  const U curr,\n  U& mu,\n  U& sigma2,\n  U& count)\n{\n  count = count + U(1);\n  U delta = curr - mu;\n  U lmean = mu + delta / count;\n  mu = lmean;\n  U delta2 = curr - lmean;\n  sigma2 = sigma2 + delta * delta2;\n}\n\ntemplate<typename U> __device__\nvoid cuChanOnlineSum(\n  const U muB,\n  const U sigma2B,\n  const U countB,\n  U& mu,\n  U& sigma2,\n  U& count)\n{\n  U delta = muB - mu;\n  U nA = count;\n  U nB = countB;\n  count = count + countB;\n  U nX = count;\n  if (nX > U(0)) {\n    nA = nA / nX;\n    nB = nB / nX;\n    mu = nA*mu + nB*muB;\n    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;\n  } else {\n    mu = U(0);\n    sigma2 = U(0);\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuWelfordMuSigma2(\n  const T* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const int i1,\n  U& mu,\n  U& sigma2,\n  U* buf) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  U count = U(0);\n  mu= U(0);\n  sigma2 = U(0);\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const T* lvals = vals + i1*n2;\n    int l = 4*thrx;\n    for (;  l+3 < n2;  l+=4*numx) {\n      for (int k = 0;  k < 4;  ++k) {\n        U curr = static_cast<U>(lvals[l+k]);\n        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);\n      }\n    }\n    for (;  l < n2;  ++l) {\n      U curr = static_cast<U>(lvals[l]);\n      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);\n    }\n    // intra-warp reductions\n    for (int l = 
0;  l <= 4;  ++l) {\n      int srcLaneB = (threadIdx.x+(1<<l))&31;\n      U muB = WARP_SHFL(mu, srcLaneB);\n      U countB = WARP_SHFL(count, srcLaneB);\n      U sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      U* ubuf = (U*)buf;\n      U* ibuf = (U*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2*wrt_y] = mu;\n          ubuf[2*wrt_y+1] = sigma2;\n          ibuf[wrt_y] = count;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          U muB = ubuf[2*threadIdx.y];\n          U sigma2B = ubuf[2*threadIdx.y+1];\n          U countB = ibuf[threadIdx.y];\n          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1]/U(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2/U(n2), 0);\n    }\n  }\n}\n\ntemplate<> __device__\nvoid cuWelfordMuSigma2(\n  const at::Half* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const int i1,\n  float& mu,\n  float& sigma2,\n  float* buf) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean 
 over n2\n  float count = 0.0f;\n  mu= float(0);\n  sigma2 = float(0);\n\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const at::Half* lvals = vals + i1*n2;\n    int l = 8*thrx;\n    if ((((size_t)lvals)&3) != 0) {\n      // 16 bit alignment\n      // first thread consumes first point\n      if (thrx == 0) {\n        float curr = static_cast<float>(lvals[0]);\n        cuWelfordOnlineSum(curr,mu,sigma2,count);\n      }\n      ++l;\n    }\n    // at this point, lvals[l] are 32 bit aligned for all threads.\n    for (;  l+7 < n2;  l+=8*numx) {\n      for (int k = 0;  k < 8;  k+=2) {\n        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));\n        cuWelfordOnlineSum(curr.x,mu,sigma2,count);\n        cuWelfordOnlineSum(curr.y,mu,sigma2,count);\n      }\n    }\n    for (;  l < n2;  ++l) {\n      float curr = static_cast<float>(lvals[l]);\n      cuWelfordOnlineSum(curr,mu,sigma2,count);\n    }\n    // intra-warp reductions\n    for (int l = 0;  l <= 4;  ++l) {\n      int srcLaneB = (threadIdx.x+(1<<l))&31;\n      float muB = WARP_SHFL(mu, srcLaneB);\n      float countB = WARP_SHFL(count, srcLaneB);\n      float sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      float* ubuf = (float*)buf;\n      float* ibuf = (float*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2*wrt_y] = mu;\n          ubuf[2*wrt_y+1] = sigma2;\n          ibuf[wrt_y] = count;\n     
   }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          float muB = ubuf[2*threadIdx.y];\n          float sigma2B = ubuf[2*threadIdx.y+1];\n          float countB = ibuf[threadIdx.y];\n          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1]/float(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2/float(n2), 0);\n    }\n  }\n}\n\ntemplate<typename U> U rsqrt(U v) {\n  return U(1) / sqrt(v);\n}\ntemplate<> float rsqrt(float v) {\n  return rsqrtf(v);\n}\ntemplate<> double rsqrt(double v) {\n  return rsqrt(v);\n}\n\nnamespace {\n// This is the un-specialized struct.  
Note that we prevent instantiation of this\n// struct by putting an undefined symbol in the function body so it won't compile.\n//  template <typename T>\n//  struct SharedMemory\n//  {\n//      // Ensure that we won't compile any un-specialized types\n//      __device__ T *getPointer()\n//      {\n//          extern __device__ void error(void);\n//          error();\n//          return NULL;\n//      }\n//  };\n// https://github.com/NVIDIA/apex/issues/246\ntemplate <typename T>\nstruct SharedMemory;\n\ntemplate <>\nstruct SharedMemory <float>\n{\n    __device__ float *getPointer()\n    {\n        extern __shared__ float s_float[];\n        return s_float;\n    }\n};\n\ntemplate <>\nstruct SharedMemory <double>\n{\n    __device__ double *getPointer()\n    {\n        extern __shared__ double s_double[];\n        return s_double;\n    }\n};\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuApplyLayerNorm(\n  T* __restrict__ output_vals,\n  U* __restrict__ mean,\n  U* __restrict__ invvar,\n  const T* __restrict__ vals,\n  const int n1,\n  const int n2,\n  const U epsilon,\n  const T* __restrict__ gamma,\n  const T* __restrict__ beta\n  ) \n{\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensors are contiguous\n  //\n  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer();\n    U mu,sigma2;\n    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);\n    const T* lvals = vals + i1*n2;\n    T* ovals = output_vals + i1*n2;\n    U c_invvar = rsqrt(sigma2 + epsilon);\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL && beta != NULL) {\n      for (int i = thrx;  i < n2;  i+=numx) {\n        U curr = static_cast<U>(lvals[i]);\n        ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];\n      }\n    } else {\n      for (int i = thrx;  i < n2;  i+=numx) {\n        U curr = static_cast<U>(lvals[i]);\n 
       ovals[i] = static_cast<T>(c_invvar * (curr - mu));\n      }\n    }\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      mean[i1] = mu;\n      invvar[i1] = c_invvar;\n    }\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuLoadWriteStridedInputs(\n    const int i1_block,\n    const int thr_load_row_off,\n    const int thr_load_col_off,\n    const int i2_off,\n    const int row_stride,\n    U* warp_buf1,\n    U* warp_buf2,\n    const T* input,\n    const T* dout,\n    const int i1_end,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar\n    )\n{\n  int i1 = i1_block+thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = invvar[i1];\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1*n2+i2;\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      if (i2<n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        warp_buf1[write_idx] = curr_dout;\n        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;\n      } else {\n        warp_buf1[write_idx] = U(0);\n        warp_buf2[write_idx] = U(0);\n      }\n    }\n  } else {\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      warp_buf1[write_idx] = U(0);\n      warp_buf2[write_idx] = U(0);\n    }\n  }\n}\n\ntemplate<typename T, typename U> __device__\nvoid cuLoadAddStridedInputs(\n    const int i1_block,\n    const int thr_load_row_off,\n    const int thr_load_col_off,\n    const int i2_off,\n    const int row_stride,\n    U* warp_buf1,\n    U* warp_buf2,\n    const T* input,\n    const T* dout,\n    const int i1_end,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar\n    )\n{\n  int i1 = i1_block+thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = 
invvar[i1];\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1*n2+i2;\n      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;\n      if (i2<n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        warp_buf1[write_idx] += curr_dout;\n        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputePartGradGammaBeta(\n    const T* __restrict__ dout,\n    const T* __restrict__ input,\n    const int n1,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar,\n    U epsilon,\n    U* part_grad_gamma,\n    U* part_grad_beta)\n{\n    const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);\n    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;\n    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;\n    const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;\n    const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1;\n    const int row_stride = blockDim.x+1;\n    const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);\n    const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;\n    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements\n    U* warp_buf1 = (U*)buf;\n    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;\n    // compute partial sums from strided inputs\n    // do this to increase number of loads in flight\n    cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);\n    for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {\n      cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);\n    }\n    __syncthreads();\n    // inter-warp reductions\n    // sum within each warp\n    U acc1 = U(0);\n    U acc2 = U(0);\n    for (int k = 0;  k < blockDim.y;  ++k) {\n      int row1 = threadIdx.y + k*blockDim.y;\n      int idx1 = row1*row_stride + threadIdx.x;\n      acc1 += warp_buf1[idx1];\n      acc2 += warp_buf2[idx1];\n    }\n    warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;\n    warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;\n    __syncthreads();\n    // sum all warps\n    for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {\n      if (threadIdx.y < offset) {\n        int row1 = threadIdx.y;\n        int row2 = threadIdx.y + offset;\n        int idx1 = row1*row_stride + threadIdx.x;\n        int idx2 = row2*row_stride + threadIdx.x;\n        warp_buf1[idx1] += warp_buf1[idx2];\n        warp_buf2[idx1] += warp_buf2[idx2];\n      }\n      __syncthreads();\n    }\n    int i2 = blockIdx.x * blockDim.x + 
threadIdx.x;\n    if (threadIdx.y == 0 && i2 < n2) {\n      int row1 = threadIdx.y;\n      int row2 = threadIdx.y + 1;\n      int idx1 = row1*row_stride + threadIdx.x;\n      int idx2 = row2*row_stride + threadIdx.x;\n      part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];\n      part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];\n    }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputeGradGammaBeta(\n    const U* part_grad_gamma,\n    const U* part_grad_beta,\n    const int part_size,\n    const int n1,\n    const int n2,\n    T* grad_gamma,\n    T* grad_beta)\n{\n    // sum partial gradients for gamma and beta\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer(); \n    int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i2 < n2) {\n      // each warp does sequential reductions until reduced part_size is num_warps\n      int num_warp_reductions = part_size / blockDim.y;\n      U sum_gamma = U(0);\n      U sum_beta = U(0);\n      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;\n      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;\n      for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {\n        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];\n        sum_beta += part_grad_beta_ptr[warp_offset*n2];\n      }\n      // inter-warp reductions\n      const int nbsize3 = blockDim.x * blockDim.y / 2;\n      for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {\n        // top half write to shared memory\n        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          buf[write_idx] = sum_gamma;\n          buf[write_idx+nbsize3] = sum_beta;\n        }\n        __syncthreads();\n        // bottom half sums\n        if (threadIdx.y < offset) {\n          const int read_idx = threadIdx.y 
* blockDim.x + threadIdx.x;\n          sum_gamma += buf[read_idx];\n          sum_beta += buf[read_idx+nbsize3];\n        }\n        __syncthreads();\n      }\n      // write out fully summed gradients\n      if (threadIdx.y == 0) {\n        grad_gamma[i2] = sum_gamma;\n        grad_beta[i2] = sum_beta;\n      }\n    }\n}\n\ntemplate<typename T, typename U> __global__\nvoid cuComputeGradInput(\n    const T* __restrict__ dout,\n    const T* __restrict__ dout_resid,\n    const T* __restrict__ input,\n    const int n1,\n    const int n2,\n    const U* __restrict__ mean,\n    const U* __restrict__ invvar,\n    U epsilon,\n    const T* gamma,\n    T* grad_input)\n{\n  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    U sum_loss1 = U(0);\n    U sum_loss2 = U(0);\n    const U c_mean = mean[i1];\n    const U c_invvar = invvar[i1];\n    const T* k_input = input + i1*n2;\n    const T* k_dout = dout + i1*n2;\n    const T* k_dout_resid = dout_resid + i1*n2;\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL) {\n      int l = 4*thrx;\n      for (;  l+3 < n2;  l+=4*numx) {\n        for (int k = 0;  k < 4;  ++k) {\n          const U c_h = static_cast<U>(k_input[l+k]);\n          const U c_loss = static_cast<U>(k_dout[l+k]);\n          sum_loss1 += c_loss * static_cast<U>(gamma[l+k]);\n          sum_loss2 += c_loss * static_cast<U>(gamma[l+k]) * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (;  l < n2;  ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss * static_cast<U>(gamma[l]);\n        sum_loss2 += c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;\n      }\n    } else {\n      int l = 4*thrx;\n      for (;  l+3 < n2;  l+=4*numx) {\n        for (int k = 0;  k < 4;  ++k) {\n          const U c_h = static_cast<U>(k_input[l+k]);\n          const U c_loss = 
static_cast<U>(k_dout[l+k]);\n          sum_loss1 += c_loss;\n          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (;  l < n2;  ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss;\n        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n      }\n    }\n    // intra-warp reductions\n    for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {\n      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);\n      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);\n    }\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      SharedMemory<U> shared;\n      U* buf = shared.getPointer(); \n      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {\n          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          buf[2*wrt_i] = sum_loss1;\n          buf[2*wrt_i+1] = sum_loss2;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.y < offset) {\n          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;\n          sum_loss1 += buf[2*read_i];\n          sum_loss2 += buf[2*read_i+1];\n        }\n        __syncthreads();\n      }\n      if (threadIdx.y == 0) {\n        buf[2*threadIdx.x] = sum_loss1;\n        buf[2*threadIdx.x+1] = sum_loss2;\n      }\n      __syncthreads();\n      if (threadIdx.y !=0) {\n        sum_loss1 = buf[2*threadIdx.x];\n        sum_loss2 = buf[2*threadIdx.x+1];\n      } \n    }\n    // all threads now have the two sums over l\n    U fH = (U)n2;\n    U term1 = (U(1) / fH) * c_invvar;\n    T* k_grad_input = grad_input + i1*n2;\n    if (gamma != NULL) {\n      for (int l = thrx;  l < n2;  l+=numx) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid= 
static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;\n      }\n    } else {\n      for (int l = thrx;  l < n2;  l+=numx) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid= static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss;\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename U> \nvoid HostApplyLayerNorm(\n    T* output,\n    U* mean,\n    U* invvar,\n    const T* input,\n    int n1,\n    int n2,\n    double epsilon,\n    const T* gamma,\n    const T* beta\n    )\n{\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    const dim3 threads(32,4,1);\n    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);\n    int nshared = \n        threads.y > 1 ? 
\n\t    threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : \n\t    0;\n    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(\n\t\t    output,\n\t\t    mean,\n\t\t    invvar,\n\t\t    input,\n\t\t    n1,n2,\n\t\t    U(epsilon),\n            gamma,beta);\n}\n\ntemplate<typename T, typename U> \nvoid HostLayerNormGradient(\n    const T* dout,\n    const T* dout_resid,\n    const U* mean,\n    const U* invvar,\n    const at::Tensor& input,\n    int n1,\n    int n2,\n    const T* gamma,\n    const T* beta,\n    double epsilon,\n    T* grad_input,\n    T* grad_gamma,\n    T* grad_beta\n    )\n{\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    if (gamma != NULL && beta != NULL) {\n      // compute grad_gamma(j) and grad_beta(j)\n      const int part_size = 16;\n      const dim3 threads2(32,4,1);\n      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);\n      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);\n      const int nshared2_b = threads2.x * threads2.y * sizeof(U);\n      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;\n      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? 
at::ScalarType::Float : input.scalar_type()));\n      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);\n      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(\n\t\t      dout,\n\t\t      static_cast<T*>(input.data_ptr()),\n\t\t      n1,n2,\n\t\t      mean,\n\t\t      invvar,\n\t\t      U(epsilon),\n\t\t      static_cast<U*>(part_grad_gamma.data_ptr()),\n\t\t      static_cast<U*>(part_grad_beta.data_ptr()));\n\n      const dim3 threads3(32,8,1);\n      const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);\n      const int nshared3 = threads3.x * threads3.y * sizeof(U);\n      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(\n\t\t      static_cast<U*>(part_grad_gamma.data_ptr()),\n\t\t      static_cast<U*>(part_grad_beta.data_ptr()),\n\t\t      part_size,\n\t\t      n1,n2,\n\t\t      grad_gamma,\n\t\t      grad_beta);\n    }\n\n    // compute grad_input\n    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);\n    const dim3 threads1(32,4,1);\n    int nshared =\n\t    threads1.y > 1 ?\n\t    threads1.y*threads1.x*sizeof(U) :\n\t    0;\n    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(\n            dout,\n\t    dout_resid,\n            static_cast<T*>(input.data_ptr()),\n            n1,n2,\n            mean,\n            invvar,\n            U(epsilon),\n            gamma,\n            grad_input);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               const uint8_t *padding_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t       bool \t\t\t\tuse_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(input.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  
\tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(\n                                 is_training,\n                                 heads, \n                                 input, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\ntorch::Tensor bwd(\n\t\t               bool use_mask,\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& padding_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n//  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n\t\t                 heads,\n                                 output_grads,\n                                 softmax_results, \n                                 dropout_mask, \n                                 use_mask ? 
static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace mask_softmax_dropout\n} // end namespace fused_softmax\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, \"Self Multihead Attention masked softmax dropout -- Forward.\");\n  m.def(\"backward\", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, \"Self Multihead Attention masked softmax dropout -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"softmax.h\"\n#include \"dropout.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(\n\t\t\t       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& input, \n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = input.size(0);\n  const int   sequences      = attn_batches / heads;\n  const int   q_seq_len      = input.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n\n  // Softmax Intermediate 
Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(input_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {\n           dropout_results,  \n           dropout_mask, \n           softmax_results\n         };\n}\n\ntorch::Tensor bwd_cuda(\n\t\t               int heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& softmax_results, \n                               torch::Tensor const& dropout_mask,\n                               const uint8_t  *padding_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   attn_batches   = 
output_grads.size(0);\n  const int   q_seq_len      = output_grads.size(1);\n  const int   k_seq_len      = q_seq_len;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n//  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  if (padding_mask == nullptr) {\n      dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n  } else{\n      dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(\n                             static_cast<half*>(output_grads.data_ptr()), \n                             static_cast<half*>(output_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(padding_mask),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n\t\t\t     heads, stream); \n  \n  }\n//backward pass is completely in-place\n  return 
output_grads;\n}\n}\n}\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/philox.h",
    "content": "#pragma once\n//Philox CUDA. \n\nclass Philox {\npublic:\n  __device__ inline Philox(unsigned long long seed,\n                           unsigned long long subsequence,\n                           unsigned long long offset) {\n    key.x = (unsigned int)seed;\n    key.y = (unsigned int)(seed >> 32);\n    counter = make_uint4(0, 0, 0, 0);\n    counter.z = (unsigned int)(subsequence);\n    counter.w = (unsigned int)(subsequence >> 32);\n    STATE = 0;\n    incr_n(offset / 4);\n  }\n  __device__ inline uint4 operator()() {\n    if(STATE == 0) {\n      uint4 counter_ = counter;\n      uint2 key_ = key;\n      //7-round philox\n      for(int i = 0; i < 6; i++) {\n        counter_ = single_round(counter_, key_);\n        key_.x += (kPhilox10A); key_.y += (kPhilox10B);\n      }\n      output = single_round(counter_, key_);\n      incr();\n    }\n    //return a float4 directly\n    //unsigned long ret;\n    //switch(STATE) {\n    //  case 0: ret = output.x; break;\n    //  case 1: ret = output.y; break;\n    //  case 2: ret = output.z; break;\n    //  case 3: ret = output.w; break;\n    //}\n    //STATE = (STATE + 1) % 4;\n    return output;\n  }\nprivate:\n  uint4 counter;\n  uint4 output;\n  uint2 key;\n  unsigned int STATE;\n  __device__ inline void incr_n(unsigned long long n) {\n    unsigned int nlo = (unsigned int)(n);\n    unsigned int nhi = (unsigned int)(n >> 32);\n    counter.x += nlo;\n    if (counter.x < nlo)\n      nhi++;\n    counter.y += nhi;\n    if (nhi <= counter.y)\n      return;\n    if (++counter.z)\n      return;\n    ++counter.w;\n  }\n  __device__ inline void incr() {\n    if (++counter.x)\n      return;\n    if (++counter.y)\n      return;\n    if (++counter.z)\n      return;\n    ++counter.w;\n  }\n  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,\n                                    unsigned int *result_high) {\n    *result_high = __umulhi(a, b);\n    return a*b;\n  }\n  __device__ inline uint4 single_round(uint4 
ctr, uint2 key) {\n    unsigned int hi0;\n    unsigned int hi1;\n    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);\n    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);\n    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};\n    return ret;\n  }\n  static const unsigned long kPhilox10A = 0x9E3779B9;\n  static const unsigned long kPhilox10B = 0xBB67AE85;\n  static const unsigned long kPhiloxSA = 0xD2511F53;\n  static const unsigned long kPhiloxSB = 0xCD9E8D57;\n};\n// Inverse of 2^32.\n#define M_RAN_INVM32 2.3283064e-10f\n__device__  __inline__ float4 uniform4(uint4 x) {\n    return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);\n\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t\t\t\t\t   bool \t\t\t\tuse_mask,\n                               bool                 use_time_mask,\n        
                       bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, 
\"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n  \n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self::cublas_gemmex::fwd, \"Self Multihead Attention Forward.\");\n  m.def(\"backward\", &multihead_attn::self::cublas_gemmex::bwd, \"Self Multihead Attention Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               //torch::Tensor const& input_biases,\n                               //torch::Tensor const& output_biases,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define 
CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t\t\t\t\t   bool \t\t\t\tuse_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 input_biases, \n                                 
output_biases, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  
AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_bias\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self_bias::cublas_gemmex::fwd, \"Self Multihead Attention with Bias -- Forward.\");\n  m.def(\"backward\", &multihead_attn::self_bias::cublas_gemmex::bwd, \"Self Multihead Attention with Bias -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n#include <cuda_fp16.h>\n\nnamespace multihead_attn {\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,  \n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const half*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                              // torch::Tensor const& softmax_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               //torch::Tensor const& input_biases,\n                               //torch::Tensor const& output_biases,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n              
                                    );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n \t\t\t\t\t\t\t   bool \t\t\t\tuse_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()         == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()  == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(use_mask                                                  , \"no mask is not supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Half, \"Only Half is supported\");\n  }\n\n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                   
              heads, \n                                 inputs, \n                                 input_weights, \n                                 output_weights, \n                                 input_biases, \n                                 output_biases, \n                                 use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()      == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()   == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  AT_ASSERTM(inputs.dim()            == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_weights.dim()     == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()    == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()      == 3, \"expected 3D tensor\");\n\n  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, \"Only HALF is 
supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(\n                                 heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n\t\t\t\t bmm1_results,\n\t\t\t\t pad_mask, \n                                 input_lin_results, \n                                 inputs, \n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_bias_additive_mask\n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, \"Self Multihead Attention with Bias and Additive Mask -- Forward.\");\n  m.def(\"backward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, \"Self Multihead Attention with Bias and Additive Mask -- Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const half*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta_zero       = 0.0;\n  const float beta_one           = 1.0;\n  const 
float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor bmm1_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());\n  void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  
THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta_zero, \n                             static_cast<half*>(bmm1_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n  // Padded Softmax\n  bool softmax_success = false;\n  if (is_training) {\n      softmax_success = 
dispatch_additive_masked_softmax_dropout<half, half, float>(\n                           reinterpret_cast<half*>(dropout_results_ptr),\n                           (is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,\n                           reinterpret_cast<const half*>(bmm1_results_ptr),\n                           pad_mask,\n      \t\t           attn_batches*q_seq_len*q_seq_len,\n                           k_seq_len,\n                           k_seq_len,\n                           attn_batches*q_seq_len,\n                           attn_batches*q_seq_len/sequences, \n      \t\t           1.0f-dropout_prob,\n\t\t           stream);\n  } else {\n      softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function\n                             reinterpret_cast<const half*>(bmm1_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta_zero, \n                             
static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           input_lin_results,  \n           bmm1_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask,\n              
                 torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = 
static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const 
void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                           
  static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half* const>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(bmm1_results.data_ptr()),\n                             reinterpret_cast<half const*>(pad_mask.data_ptr()),\n\t\t\t     static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n\t\t\t     1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n\t\t\t     attn_batches*q_seq_len/sequences,\n                             attn_batches*q_seq_len,\n\t\t\t     stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n        
                     q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n\t\t\t     static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                             //static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             
CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads,\n           input_bias_grads, \n           output_bias_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta_zero       = 0.0;\n  const float beta_one           = 1.0;\n  const float scale  
        = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                   
          CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta_zero, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             
reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n\n\n  if (is_training) {\n    //use at:: function so that C++ version generates the same random mask as python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                          
   batch_stride, \n                             (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta_zero, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta_one),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           input_lin_results,  \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n        
                       torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor 
Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  
THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n\n  // Matmul2 Dgrad2\n  
gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             1.0/(1.0-dropout_prob),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len, stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             
static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                             //static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n 
                            static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto  input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads,\n           input_bias_grads, \n           output_bias_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_bias\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle 
= at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options  = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n     
                        static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n                             output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  
} else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(softmax_results.data_ptr()),\n                               static_cast<at::Half*>(dropout_results.data_ptr()),\n                               static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                               dropout_elems,\n                               (1.0f - dropout_prob));\n  }\n \n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()), \n                             head_dim*attn_batches, \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(outputs.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_lin_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               
torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_grads         = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads          = 
torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n  \n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                    
         embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(output_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             
q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                 
            lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             
static_cast<const void*>(&beta),\n                             static_cast<void*>(input_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(inputs.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           input_grads, \n           input_weight_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self\n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\nnamespace multihead_attn {\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                                  );\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                         
      torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  );\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> fwd(\n                               bool                 use_mask,\n                               bool                 use_time_mask,\n                               bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& pad_mask,\n                               float                dropout_prob\n                                                 )\n{\n  AT_ASSERTM(inputs.dim()                 == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()   == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights.dim()          == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()         == 2, \"expected 2D tensor\");\n\n  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half, 
\"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n  \tAT_ASSERTM(pad_mask.dim()                     == 2,                    \"expected 2D tensor\");\n  \tAT_ASSERTM(pad_mask.type().scalarType()       == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n  \n  return fwd_cuda(\n                                 use_time_mask,\n                                 is_training,\n                                 heads, \n                                 inputs,\n                                 lyr_nrm_gamma_weights,\n                                 lyr_nrm_beta_weights,\n                                 input_weights, \n                                 output_weights, \n                                 use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, \n                                 dropout_prob\n                                );\n}\n\n\nstd::vector<torch::Tensor> bwd(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               
torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask,\n                               float                dropout_prob\n                                                  )\n{\n  AT_ASSERTM(output_grads.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(matmul2_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(softmax_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(input_lin_results.dim()     == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_results.dim()       == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, \"expected 1D tensor\");\n  AT_ASSERTM(inputs.dim()                == 3, \"expected 3D tensor\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, \"expected 1D tensor\");\n  AT_ASSERTM(input_weights.dim()         == 2, \"expected 2D tensor\");\n  AT_ASSERTM(output_weights.dim()        == 2, \"expected 2D tensor\");\n  AT_ASSERTM(dropout_mask.dim()          == 3, \"expected 3D tensor\");\n  AT_ASSERTM(dropout_add_mask.dim()      == 3, \"expected 3D tensor\");\n  \n  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_lin_results.type().scalarType()     == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half,  \"Only HALF is supported\");\n  
AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, \"Only FLOAT is supported\");\n  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half,  \"Only HALF is supported\");\n  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte,  \"Only BYTE is supported\");\n  \n  return bwd_cuda(heads, \n                                 output_grads,\n                                 matmul2_results,\n                                 dropout_results,\n                                 softmax_results, \n                                 input_lin_results, \n                                 lyr_nrm_results,\n                                 lyr_nrm_mean,\n                                 lyr_nrm_invvar,\n                                 inputs, \n\t\t\t\t\t\t\t     lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t\t lyr_nrm_beta_weights,\n                                 input_weights,\n                                 output_weights,\n                                 dropout_mask, \n                                 dropout_add_mask,\n                                 dropout_prob\n                                );\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_norm_add \n} // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n 
 m.def(\"forward\", &multihead_attn::self_norm_add::cublas_gemmex::fwd, \"Self Multihead Attention Plus Layer Norm and Residual Add Forward.\");\n  m.def(\"backward\", &multihead_attn::self_norm_add::cublas_gemmex::bwd, \"Self Multihead Attention Plus Layer Norm and Residual Add Backward.\");\n}\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
    "content": "#include <vector>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include \"THC/THC.h\"\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include <math.h>\n\n#include \"strided_batched_gemm.h\"\n#include \"softmax.h\"\n#include \"dropout.h\"\n#include \"layer_norm.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\nnamespace multihead_attn {\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(\n                               bool                 use_time_mask,\n\t\t\t\t\t\t\t   bool                 is_training,\n                               int                  heads,\n                               torch::Tensor const& inputs, \n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_gamma_weights,\n\t\t\t\t\t\t\t   torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               const uint8_t*       pad_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   total_tokens   = batches * embed_dim;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float 
scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n \n  // There is no reason to use more than one stream as every kernel is \n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)\n  auto act_options                = inputs.options().requires_grad(false);\n  auto lyr_nrm_options            = act_options.dtype(torch::kFloat32);\n  auto mask_options               = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor lyr_nrm_mean      = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar    = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results   = torch::empty_like(inputs, act_options);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);\n  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);\n  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);\n  torch::Tensor output_lin_results= torch::empty_like(inputs, act_options);\n  torch::Tensor dropout_add_mask  = torch::empty_like(inputs, mask_options);\n  torch::Tensor outputs           = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr   = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr   = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr   = 
static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n  \n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half,float>(\n                             static_cast<at::Half*>(lyr_nrm_results.data_ptr()),\n                             static_cast<float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<float*>(lyr_nrm_invvar.data_ptr()),\n                             static_cast<const at::Half*>(inputs.data_ptr()),\n                             static_cast<int>(batches),   // n1\n                             static_cast<int>(embed_dim), // n2\n                             1.0e-5,\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n\t\t\t\t\t\t\t static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Fwd\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             output_lin_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             //static_cast<const void*>(inputs.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             q_lin_results_ptr,\n                             CUDA_R_16F, \n           
                  output_lin_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             scale, \n                             static_cast<const half*>(k_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             static_cast<const half*>(q_lin_results_ptr),\n                             lead_dim, \n                             batch_stride, \n                             beta, \n                             static_cast<half*>(softmax_results_ptr), \n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             q_seq_len);\n 
   } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n                             reinterpret_cast<half*>(softmax_results_ptr),\n                             reinterpret_cast<const half*>(softmax_results_ptr),\n                             pad_mask,\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len,\n                             attn_batches*q_seq_len/sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(softmax_results.data_ptr()), \n                             static_cast<at::Half*>(dropout_results.data_ptr()), \n                             static_cast<uint8_t*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr), \n                             lead_dim, \n                             batch_stride, \n                             (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , \n                             //static_cast<const half*>(dropout_results.data_ptr()), \n                             k_seq_len,  \n                             k_seq_len*q_seq_len, \n                             beta, \n                             static_cast<half*>(matmul2_results.data_ptr()),  \n                             head_dim*attn_batches,  \n                             head_dim, \n                             attn_batches);\n\n  // Output Linear\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_T, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // End-of-block Dropout-Add \n  if (is_training) {\n    apex_dropout_add_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                               static_cast<at::Half const*>(inputs.data_ptr()), \n                               static_cast<at::Half*>(outputs.data_ptr()), \n                               
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),\n                               total_tokens,\n                               (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half,float,uint32_t>(\n                               static_cast<at::Half const*>(output_lin_results.data_ptr()), \n                               static_cast<at::Half const*>(inputs.data_ptr()), \n                               static_cast<at::Half*>(outputs.data_ptr()), \n                               total_tokens);\n  }\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return { \n           lyr_nrm_results,\n           lyr_nrm_mean,\n           lyr_nrm_invvar, \n           input_lin_results, \n           softmax_results,\n           dropout_results, \n           dropout_mask, \n           matmul2_results,\n           dropout_add_mask, \n           outputs\n         };\n}\n\nstd::vector<torch::Tensor> bwd_cuda(\n                               int                  heads,\n                               torch::Tensor const& output_grads, \n                               torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results,\n                               torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results,\n                               torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, \n                               torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask,\n                               torch::Tensor const& 
dropout_add_mask,\n                               float                dropout_prob\n                                   ) \n{\n  const int   embed_dim      = inputs.size(2);\n  const int   sequences      = inputs.size(1);\n  const int   q_seq_len      = inputs.size(0);\n  const int   k_seq_len      = q_seq_len;\n  const int   batches        = sequences * q_seq_len;\n  const int   total_tokens   = batches * embed_dim;\n  const int   head_dim       = embed_dim / heads;\n  const int   output_lin_dim = 3 * embed_dim;\n  const int   attn_batches   = heads * sequences;\n  const int   lead_dim       = attn_batches * 3 * head_dim;\n  const int   batch_stride   = 3 * head_dim;\n  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;\n  const float alpha          = 1.0;\n  const float beta           = 0.0;\n  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n \n  // Output Tensor Allocations\n  torch::Tensor input_grads            = torch::empty_like(inputs);\n  torch::Tensor lyr_nrm_gamma_grads    = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads     = torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_grads     = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads    = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  torch::Tensor dropout_add_grads      = torch::empty_like(output_grads);\n  torch::Tensor output_lin_grads       = torch::empty_like(matmul2_results);\n  torch::Tensor matmul2_grads          = torch::empty_like(dropout_results);\n  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n  torch::Tensor input_lin_grads    
    = torch::empty_like(inputs);\n \n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;\n  \n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'}; \n  \n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Dropout Add Backward  \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(output_grads.data_ptr()),\n                             static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),\n                             total_tokens,\n                             (1.0 / (1.0 - dropout_prob)));\n \n  // Output Linear Dgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim, \n                             batches, \n                             embed_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(output_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_lin_grads.data_ptr()),\n                             
CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n \n  // Output Linear Wgrad\n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             embed_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             static_cast<const void*>(matmul2_results.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(dropout_add_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim, \n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(output_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(     state, \n                             a_layout_t, \n                             b_layout_n, \n                             k_seq_len,\n                             q_seq_len,\n                             head_dim,\n                             alpha, \n                             static_cast<const half*>(v_lin_results_ptr),\n                             lead_dim, \n                             batch_stride,\n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             beta, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                      
       k_seq_len*q_seq_len,\n                             attn_batches);\n  \n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             alpha, \n                             static_cast<const half*>(output_lin_grads.data_ptr()),\n                             head_dim*attn_batches, \n                             head_dim, \n                             static_cast<const half*>(dropout_results.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             v_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability \n  apex_masked_scale_cuda<at::Half,float,uint32_t>(\n                             static_cast<at::Half const*>(matmul2_grads.data_ptr()),\n                             static_cast<at::Half*>(matmul2_grads.data_ptr()),\n                             static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n                             dropout_elems,\n                             (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             static_cast<half*>(matmul2_grads.data_ptr()), \n                             reinterpret_cast<half const*>(softmax_results.data_ptr()),\n                             k_seq_len,\n                             k_seq_len,\n                             attn_batches*q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(     state, \n                       
      a_layout_n, \n                             b_layout_n, \n                             head_dim, \n                             q_seq_len, \n                             k_seq_len, \n                             scale, \n                             k_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             q_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n  \n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(     state, \n                             a_layout_n, \n                             b_layout_t, \n                             head_dim, \n                             k_seq_len, \n                             q_seq_len, \n                             scale, \n                             q_lin_results_ptr, \n                             lead_dim, \n                             batch_stride, \n                             static_cast<half*>(matmul2_grads.data_ptr()),\n                             k_seq_len, \n                             k_seq_len*q_seq_len, \n                             beta, \n                             k_lin_grads_ptr, \n                             lead_dim, \n                             batch_stride, \n                             attn_batches);\n\n  // Input Linear Dgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_N,\n                             embed_dim,\n                             batches, \n                             output_lin_dim,\n                             static_cast<const void*>(&alpha),\n                             static_cast<const 
void*>(input_weights.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F, \n                             output_lin_dim, \n                             static_cast<const void*>(&beta),\n                             //static_cast<void*>(input_grads.data_ptr()),\n                             static_cast<void*>(input_lin_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  \n  // Input Linear Wgrad  \n  THCublasCheck(cublasGemmEx(handle,\n                             CUBLAS_OP_N, \n                             CUBLAS_OP_T,\n                             embed_dim, \n                             output_lin_dim,\n                             batches, \n                             static_cast<const void*>(&alpha),\n                             //static_cast<const void*>(inputs.data_ptr()),\n                             static_cast<const void*>(lyr_nrm_results.data_ptr()),\n                             CUDA_R_16F,\n                             embed_dim,\n                             static_cast<const void*>(q_lin_grads_ptr),\n                             CUDA_R_16F,\n                             output_lin_dim,\n                             static_cast<const void*>(&beta),\n                             static_cast<void*>(input_weight_grads.data_ptr()),\n                             CUDA_R_16F, \n                             embed_dim,\n                             CUDA_R_32F,\n                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half,float>(\n                             static_cast<const 
half*>(input_lin_grads.data_ptr()),\n                             static_cast<half const*>(output_grads.data_ptr()), \n                             static_cast<const float*>(lyr_nrm_mean.data_ptr()),\n                             static_cast<const float*>(lyr_nrm_invvar.data_ptr()),\n                             inputs,\n                             static_cast<int>(batches),   // n1\n                             static_cast<int>(embed_dim), // n2\n                             static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),\n                             static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),\n                             1.0e-5,\n                             static_cast<half*>(input_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),\n                             static_cast<half*>(lyr_nrm_beta_grads.data_ptr())\n                                   );\n\n  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {\n           input_grads, \n           lyr_nrm_gamma_grads, \n           lyr_nrm_beta_grads, \n           input_weight_grads, \n           output_weight_grads\n         };\n}\n\n} // end namespace cublas_gemmex\n} // end namespace self_norm_add \n} // end namespace multihead_attn\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/softmax.h",
    "content": "#pragma once\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n#include <curand_kernel.h>\n#include \"philox.h\"\n \n#include <assert.h>\n#include <cfloat>\n#include <limits>\n#include <stdint.h>\n#include <cuda_fp16.h>\n#include <cmath>\n \nnamespace {\n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);\n \n    template <>\n    __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }\n \n    template <>\n    __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }\n \n    template <>\n    __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } \n    template <>\n    __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }\n    \n    template <>\n    __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }\n   \n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);\n    \n    template <>\n    __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {\n      if (*src == 1) { *dst = value; }\n    }\n    template <typename Datatype, int ELEMENTS_PER_LDG>\n    __device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask);\n    template <>\n    __device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {\n      *dst += *additive_mask; \n    }\n    template <>\n    __device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {\n      *dst += *additive_mask;\n      *(dst+1) += *(additive_mask+1);\n      *(dst+2) += *(additive_mask+2);\n      
*(dst+3) += *(additive_mask+3);}    \n} // namespace anonymous\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp Softmax forward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batch_size, int stride, int element_count)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_forward_func = void(*)(input_t *dst, const output_t *src, int batch_size, int stride, 
int element_count);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, int 
softmax_elements_stride, int batch_count)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)\n{\n \n    assert(ELEMENTS_PER_LDG_STG==4);\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n    int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n    acc_t pinv = acc_t(1)/p;\n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n     \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n    //vectorize if element_count is multiple of 4, else don't vectorize\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n    dropout_mask += thread_offset;\n    \n    // load data from global memory\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                //masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n    \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()\n            } \n    \n        }\n    }\n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) 
{\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n    auto seeds = at::cuda::philox::unpack(philox_args);\n    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));     \n    uint8_t rands[WARP_BATCH][WARP_ITERATIONS];\n    float4 rand_num;\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n\t#pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it+=ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n\t\trand_num = uniform4(ph());\n                rands[i][it] = (rand_num.x <= p) > 0.5;  \n                rands[i][it+1] = (rand_num.y <= p) > 0.5;\n                rands[i][it+2] = (rand_num.z <= p) > 0.5;\n                rands[i][it+3] = (rand_num.w <= p) > 0.5;\n                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);\n\t    }\n        }\n    }\n\n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                output_t 
out[ELEMENTS_PER_LDG_STG];\n                #pragma unroll\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = rands[i][it+element] * (pinv * (elements[i][it + element] / sum[i]));\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n    \n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n    int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n    acc_t pinv = acc_t(1)/p;\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n     \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n    //vectorize if element_count is multiple of 4, else don't vectorize\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    int thread_offset =  first_batch * stride + local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n    dropout_mask += thread_offset;\n    \n    // load data from global memory\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += 1) {\n            int element_index = local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < 1;++element) {\n    \t//masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n    \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);\n                apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp); \n            } \n    \n        }\n    }\n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n    curandStatePhilox4_32_10_t state;\n    auto seeds = at::cuda::philox::unpack(philox_args);\n    curand_init(\n      std::get<0>(seeds),\n      tid,\n      std::get<1>(seeds),\n      &state);\n     \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += 1) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                output_t out[1];\n                acc_t softmax_out[1];\n                uint8_t dropout_mask_temp[1];\n                //generate a vector of random numbers here \n                float rand = curand_uniform(&state);\n                float *rand_ptr = (float*)(&rand);    \n                #pragma unroll\n                for (int 
element = 0;element < 1;++element) {\n    \t        softmax_out[element] = (elements[i][it + element] / sum[i]);\t\n                    rand_ptr[element] = rand_ptr[element] <= p;       \n                    out[element] = rand_ptr[element] * pinv * softmax_out[element];\n    \t            dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f \n                }\n                copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);\n                copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);\n    \n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1 (scalar path) or 4 (vectorized path).\ntemplate <typename input_t, typename output_t, typename acc_t>\nusing additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride,  at::PhiloxCudaState philox_args, float p);\n\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n    bool flag_vec4 = (element_count % 4 == 0); \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n\tif (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n\tif (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,16,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,32,32,4>;\n\telse kernel = 
&additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    case 11: // 2048\n        if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,64,32,4>;\n\telse kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,64,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\n\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int totalElements, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, cudaStream_t streamid)// p is the probability to keep, not drop\n{\n\t\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 2048) {\n        // compute function index. there's a function for each power of two size up to 2048.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n\tc10::optional<at::Generator> gen_;\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        int64_t 
counter_offset = (totalElements/(blocks*threads_per_block)+1);\n        at::PhiloxCudaState rng_engine_inputs;\n\t{\n          std::lock_guard<std::mutex> lock(gen->mutex_);\n\t  rng_engine_inputs = gen->philox_cuda_state(counter_offset);\n        }\n \n        // compute launch size\n        dim3 threads(warp_size, warps_per_block, 1);\n         \n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride, rng_engine_inputs, p);\n        return true;\n    }\n    return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const half* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n\t\t//masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                //apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                //                                          (__half)-std::numeric_limits<float>::infinity(), \n                //                                          curr_mask + itr_jmp);\n                elements_input[i][it] += *(curr_mask + itr_jmp);\n\t    } \n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the 
max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = 
elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing additive_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const half *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        additive_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n        additive_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\n\n\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                                                          (__half)-std::numeric_limits<float>::infinity(), \n                                                          curr_mask + itr_jmp);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            
elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, 
int batch_size, int stride, int element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename 
input_t, typename output_t, typename acc_t>\nbool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *src, const 
uint8_t *pad_mask, int batch_size, int stride, int element_count, int mod_seq_len)\n{\n    assert(ELEMENTS_PER_LDG_STG==1);\n \n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    src += thread_offset;\n    dst += thread_offset;\n \n    // load data from global memory\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask    = pad_mask + pad_thread_offset;\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n            }\n \n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], \n                                                          (__half)-std::numeric_limits<float>::infinity(), \n                                                   
       curr_mask + itr_jmp);\n            }\n \n        }\n    }\n \n    // convert input_t to acc_t\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = elements[i][it + element] / sum[i];\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n            }\n            else {\n                break;\n            }\n        }\n    }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing time_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t 
*pad_mask, int batch_size, int stride, int element_count, int mod_seq_len);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, time_masked_softmax_forward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        
return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int mod_seq_len)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        time_masked_softmax_forward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize GPU utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);\n        return true;\n    }\n    return false;\n}\n\nint log2_ceil_native(int value) {\n    int log2_value = 0;\n    while ((1 << log2_value) < value) ++log2_value;\n    return log2_value;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)\n{\n#if CUDA_VERSION >= 9000\n    return __shfl_xor_sync(mask, value, laneMask, width);\n#else\n    return __shfl_xor(value, laneMask, 
width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_BATCH, int WARP_SIZE>\n__device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;  i < WARP_BATCH;  ++i) {\n            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);\n            sum[i] = sum[i] + b;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward functions as fused variants of at::softmax_backward_data function\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// softmax backward data function is taken from native PyTorch; elementwise mul is fused in the epilogue, as well as masking and scaling for fused dropout\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, int stride, int element_count, int heads)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. 
Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]  ;\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; \n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                int total_ind = thread_offset + i*element_count + it*WARP_SIZE;\n                int pad_mask_ind = element_count*(total_ind/(heads * element_count * element_count)) + total_ind%element_count;\n                uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];\n                if (pad_mask_element == 0) {\n                    gradInput[i*element_count+it*WARP_SIZE] = 0;\n                } else {\n                    if (is_log_softmax) {\n                        gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                    } else {\n                        gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                    }\n                }\n            }\n        }\n    
}\n}\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n        // use 128 threads per block to maximize GPU utilization\n        constexpr int threads_per_block = 128;\n\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n             
       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 3: // 8\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 4: // 16\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 7: // 
128\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 9: // 512\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads, cudaStream_t streamid)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        
int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n        // use 128 threads per block to maximize GPU utilization\n        constexpr int threads_per_block = 128;\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 3: // 8\n   
             masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 4: // 16\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 7: // 128\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 9: // 512\n                
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]  ;\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; \n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  
it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                if (is_log_softmax) {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                } else {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                }\n            }\n        }\n    }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n    // vectorize if the row length is a multiple of 4 (parenthesized: == binds tighter than &)\n    bool flag_vec4 = ((element_count & 3) == 0);\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n    // the first element to process by the current thread\n    int thread_offset =  first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n\n    grad += thread_offset;\n    softmax_input += thread_offset;\n    gradInput += thread_offset;\n    mask += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const input_t* curr_mask    = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n            #pragma unroll\n            for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                // masking_value is a large negative value\n                elements_input[i][it + element] = -10000;\n                grad_reg[i][it+element] = acc_t(0);\n            }\n\n            if (element_index < batch_element_count) {\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);\n                apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()\n                uint8_t mask_temp[ELEMENTS_PER_LDG_STG];\n                input_t grad_temp[ELEMENTS_PER_LDG_STG];\n                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);\n                #pragma unroll\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    grad_reg[i][it+element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale );\n                }\n            }\n\n        }\n    }\n\n    // convert input_t to acc_t\n    // TODO : remove this, input is already acc_t type in register\n    acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        #pragma unroll\n     
   for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements_input[i][it];\n        }\n    }\n \n    constexpr uint32_t  FULL_MASK = 0xffffffff;\n \n    // compute local max_value\n \n    // take the max_value of the first element to avoid one max call\n    acc_t max_value[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        max_value[i] = elements[i][0];\n    }\n \n    #pragma unroll\n    for (int it = 1;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n        }\n    }\n \n    // reduction max_value\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        float val[WARP_BATCH];\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n        }\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n        }\n    }\n \n    // compute local sum\n    acc_t sum[WARP_BATCH] { 0.0f };\n \n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            //elements[i][it] = expf(elements[i][it] - max_value[i]);\n            elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n            sum[i] += elements[i][it];\n        }\n    }\n \n    // reduction sum\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n\n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            elements[i][it] = elements[i][it] / sum[i];\n            grad_reg[i][it] = grad_reg[i][it] * elements[i][it];\n        }\n    }\n\n    acc_t grad_sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        grad_sum[i] = grad_reg[i][0]; \n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            grad_sum[i] += grad_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t grad_input_reg[ELEMENTS_PER_LDG_STG];\n                #pragma unroll\n                for (int element=0; element<ELEMENTS_PER_LDG_STG; element++) {\n                    if (is_log_softmax) {\n                        grad_input_reg[element] = (grad_reg[i][it+element] - std::exp(elements[i][it+element]) * grad_sum[i]);\n                    } else {\n                        grad_input_reg[element] = (grad_reg[i][it+element] - elements[i][it+element] * grad_sum[i]);\n                    }\n                }\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);\n            }\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nusing masked_scale_softmax_warp_backward_recompute_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n    bool flag_vec4 = (element_count % 4 == 0);\n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,1,1, is_log_softmax>;\n        break;\n    case 1: // 2\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,2,1, is_log_softmax>;\n        break;\n    case 2: // 4\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,4,1, is_log_softmax>;\n        break;\n    case 3: // 8\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,8,1, is_log_softmax>;\n        break;\n    case 4: // 16\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,16,1, is_log_softmax>;\n        break;\n    case 5: // 32\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,32,1, is_log_softmax>;\n        break;\n    case 6: // 64\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,2,32,1, is_log_softmax>;\n        break;\n    case 7: // 128\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,4,32,1, is_log_softmax>;\n        break;\n    case 8: // 256\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,1, is_log_softmax>;\n        break;\n    case 9: // 512\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,1, is_log_softmax>;\n        break;\n    case 10: // 1024\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,4, 
is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,1, is_log_softmax>;\n        break;\n    case 11: // 2048\n        if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,4, is_log_softmax>;\n        else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,1, is_log_softmax>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 2048) {\n        // compute function index. 
there's a function for each power of two size up to 2048.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;\n        int warp_size, batches_per_warp;\n        if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n\n        // compute launch size\n        dim3 threads(warp_size, warps_per_block, 1);\n\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride, pad_batch_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = 
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 1: // 2\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 2: // 4\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 3: // 8\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 4: // 16\n                
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 5: // 32\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 6: // 64\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 7: // 128\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 8: // 256\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 9: // 512\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 10: // 1024\n                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, 
softmax_elements_stride, softmax_elements);\n                break;\n            default:\n                break;\n        }\n    }\n}\n\n// The elementwise multiplication performed in at::softmax_backward_data is fused into the softmax dgrad kernel.\n// As a result of the fusion, the intermediate multiplication result is kept in fp32 registers instead of fp16.\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count)\n{\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.\n    constexpr int next_power_of_two = 1 << log2_elements;\n    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n\n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x % WARP_SIZE;\n\n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n\n    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,\n    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep\n    // the nested loops.\n    // This should have no impact on performance because the loops are unrolled anyway.\n\n    // load data from global memory\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]  ;\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index = local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]*output[i*element_count+it*WARP_SIZE];\n                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];\n            } else {\n                grad_reg[i][it] = acc_t(0);\n                output_reg[i][it] = acc_t(0);\n            }\n        }\n    }\n\n    acc_t sum[WARP_BATCH];\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        sum[i] = grad_reg[i][0]; //* output_reg[i][0];\n        #pragma unroll\n        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {\n            sum[i] += grad_reg[i][it];// * output_reg[i][it];\n        }\n    }\n    warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n    // store result\n    #pragma unroll\n    for (int i = 0;  i < WARP_BATCH;  ++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {\n            int element_index 
= local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                if (is_log_softmax) {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n                } else {\n                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n                }\n            }\n        }\n    }\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)\n{\n    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );\n    if (softmax_elements == 0) {\n       return;\n    } else {\n        int log2_elements = log2_ceil_native(softmax_elements);\n        const int next_power_of_two = 1 << log2_elements;\n\n        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n        int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n\n        int warps_per_block = (threads_per_block / warp_size);\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n        switch (log2_elements) {\n            case 0: // 1\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 1: // 2\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 2: // 4\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 3: // 8\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 4: // 16\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, 
softmax_elements_stride, softmax_elements);\n                break;\n            case 5: // 32\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 6: // 64\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 7: // 128\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 8: // 256\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 9: // 512\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            case 10: // 1024\n                softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>\n                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n                break;\n            default:\n                break;\n        }\n 
   }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n \n    // load data from global memory\n    input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n                copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert half to floating point\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            grad_reg[i][it] = grad_reg_input[i][it];\n            output_reg[i][it] = output_reg_input[i][it];\n        }\n    }\n \n \n    // compute thread local sum\n    acc_t sum[WARP_BATCH] = {0};\n    #pragma unroll\n    for (int it = 0;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += grad_reg[i][it] * output_reg[i][it];\n \n        }\n    }\n \n    // reduction sum\n    constexpr uint32_t FULL_MASK = 0xffffffff;\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t 
out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));\n                }\n                // store them in global memory\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);\n            }\n        }\n    }\n}\n \n \n \n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_backward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate<typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. 
there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n        softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n        // use 128 threads per block to maximize gpu utilization\n        constexpr int threads_per_block = 128;\n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n        // launch\n        kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);\n        return true;\n    }\n    return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>\n__global__ void masked_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)\n{\n    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n \n    // batch_size might not be a multiple of WARP_BATCH. Check how\n    // many batches have to be computed within this WARP.\n    int local_batches = batch_size - first_batch;\n    if (local_batches > WARP_BATCH)\n        local_batches = WARP_BATCH;\n \n    // there might be multiple batches per warp. 
compute the index within the batch\n    int local_idx = threadIdx.x;\n \n    // the first element to process by the current thread\n    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    grad += thread_offset;\n    output += thread_offset;\n    gradInput += thread_offset;\n \n    // load data from global memory\n    input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        int batch_element_count = (i >= local_batches) ? 0 : element_count;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < batch_element_count) {\n                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n                copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);\n            }\n \n        }\n    }\n \n    // convert half to floating point\n    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        for (int it = 0;it < WARP_ITERATIONS;++it) {\n            grad_reg[i][it] = grad_reg_input[i][it];\n            output_reg[i][it] = output_reg_input[i][it];\n        }\n    }\n \n \n    // compute thread local sum\n    acc_t sum[WARP_BATCH] = {0};\n    #pragma unroll\n    for (int it = 0;it < WARP_ITERATIONS;++it) {\n        for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += grad_reg[i][it] * output_reg[i][it];\n \n        }\n    }\n \n    // reduction sum\n    constexpr uint32_t FULL_MASK = 0xffffffff;\n    #pragma unroll\n    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n        #pragma unroll\n   
     for (int i = 0;i < WARP_BATCH;++i) {\n            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n        }\n    }\n \n    // store result\n    #pragma unroll\n    for (int i = 0;i < WARP_BATCH;++i) {\n        if (i >= local_batches)\n            break;\n        int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n        const uint8_t* curr_mask = pad_mask + pad_thread_offset;\n        #pragma unroll\n        for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {\n            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n            if (element_index < element_count) {\n                // compute gradients\n                output_t out[ELEMENTS_PER_LDG_STG];\n                for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {\n                    out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));\n                }\n                // store them in global memory\n                int itr_jmp = it * WARP_SIZE;\n                int itr_idx = i * element_count + itr_jmp;\n                // It is kind of unfortunate this has to be here to zero something out that is close to\n                // zero in the first place\n                apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);\n                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);\n            }\n        }\n    }\n}\n \n \n \n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.\n// WARP_SIZE number of elements working on a single batch, has to be a power of two.\n// ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int batch_size, int stride, int 
element_count, int pad_batch_stride);\n \ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_backward_func<input_t, output_t> &kernel) {\n    // determine size of a warp\n    const int next_power_of_two = 1 << log2_elements;\n    warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n \n    // determine how many batches a warp should process.\n    batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n \n    switch (log2_elements) {\n    case 0: // 1\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;\n        break;\n    case 1: // 2\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;\n        break;\n    case 2: // 4\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;\n        break;\n    case 3: // 8\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;\n        break;\n    case 4: // 16\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;\n        break;\n    case 5: // 32\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;\n        break;\n    case 6: // 64\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;\n        break;\n    case 7: // 128\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;\n        break;\n    case 8: // 256\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;\n        break;\n    case 9: // 512\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;\n        break;\n    case 10: // 1024\n        kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;\n        break;\n    default:\n        return false;\n    }\n    return true;\n}\n \ntemplate<typename input_t, typename 
output_t, typename acc_t>\nbool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)\n{\n    if (softmax_elements == 0) {\n        return true;\n    } else if (softmax_elements <= 1024) {\n        // compute function index. there's a function for each power of two size up to 1024.\n        int log2_elements = 0;\n        while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n \n        masked_softmax_backward_func<input_t, output_t> kernel;\n        int warp_size, batches_per_warp;\n        if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n            return false;\n        }\n \n        // use 128 threads per block to maximize GPU utilization\n        constexpr int threads_per_block = 128;\n \n        // compute warps per block.\n        int warps_per_block = (threads_per_block / warp_size);\n \n        // compute launch size\n        int batches_per_block = warps_per_block * batches_per_warp;\n        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n        dim3 threads(warp_size, warps_per_block, 1);\n \n        // launch\n        kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n        return true;\n    }\n    return false;\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/multihead_attn/strided_batched_gemm.h",
    "content": "#include <vector>\n#include <iostream>\n\n//#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include \"THC/THC.h\"\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/gemm/gemm.h\"\n#include \"cutlass/gemm/wmma_gemm_traits.h\"\n\n// symbol to be automatically resolved by PyTorch libs\nextern THCState *state;\n\ncublasOperation_t convertTransToCublasOperation(char trans) {\n  if (trans == 't') return CUBLAS_OP_T;\n  else if (trans == 'n') return CUBLAS_OP_N;\n  else if (trans == 'c') return CUBLAS_OP_C;\n  else {\n    THError(\"trans must be one of: t, n, c\");\n    return CUBLAS_OP_T;\n  }\n}\n\nvoid CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k,\n                    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                    float beta, half *c, long ldc, long strideC, long batchCount, cublasGemmAlgo_t algo=CUBLAS_GEMM_DEFAULT_TENSOR_OP) {\n    cublasOperation_t opa = convertTransToCublasOperation(transa);\n    cublasOperation_t opb = convertTransToCublasOperation(transb);\n\n    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n    cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();\n    cublasSetStream(handle, stream);\n    float fAlpha = alpha;\n    float fBeta = beta;\n    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n    THCublasCheck(cublasGemmStridedBatchedEx(handle,\n                                     opa, opb, (int)m, (int)n, (int)k,\n                                     (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,\n                                     b, CUDA_R_16F, (int)ldb, strideB,\n                                     (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,\n                                     (int)batchCount, CUDA_R_32F, 
algo));\n    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n}\n\ntemplate<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>\nvoid CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,\n                          float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                          float beta, half *c, long ldc, long strideC, long batchCount) {\n  //printf(\"CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\\n\", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);\n  typedef cutlass::gemm::WmmaGemmTraits<\n    A_LAYOUT,\n    B_LAYOUT,\n    cutlass::Shape<32, 16, 16>,\n    half,\n    half,\n    half,\n    cutlass::gemm::LinearScaling<float>,\n    float,\n    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,\n    typename cutlass::Shape<16, 16, 16>,\n    SRC_A,   //kScalarsPerLdgA_\n    SRC_B,   //kScalarsPerLdgB_\n    SRC_A,   //KScalarsPerLdsA_\n    SRC_B,   //KScalarsPerLdsB_\n    DST_C,   //kScalarsPerLdgCAndStgD_\n    DST_C/2, //kScalarsPerStsD_\n    DST_C/2  //kScalarsPerLdsD_\n  >\n    WmmaGemmTraits;\n\n  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;\n  typename Gemm::Params params;\n\n\n  int result = params.initialize(\n    m,                  // M dimension for each batch\n    n,                  // N dimension for each batch\n    k,                  // K dimension for each batch\n    alpha,              // scalar alpha\n    a,\n    lda,\n    strideA,     // distance in memory between the first element of neighboring batch\n    b,\n    ldb,\n    strideB,     // distance in memory between the first element of neighboring batch\n    beta,               // scalar beta\n    c,                  // 
source matrix C\n    ldc,\n    strideC,     // distance in memory between the first element of neighboring batch\n    c,                  // destination matrix C (may be different memory than source C matrix)\n    ldc,\n    strideC,    // distance in memory between the first element of neighboring batch\n    batchCount\n  );\n\n  AT_ASSERTM(result == 0, \"Failed to initialize CUTLASS Gemm::Params object.\");\n  \n  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits. \n  // To implement batched GEMM with larger batch size, we fragment it into\n  // smaller batched GEMMs of gridDim.z <= 64k\n  long batchesLeft    = batchCount;\n  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));\n  \n  do {\n  \t //printf(\"CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\\n\", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 
'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);\n    int result = params.initialize(\n      m,                  // M dimension for each batch\n      n,                  // N dimension for each batch\n      k,                  // K dimension for each batch\n      alpha,              // scalar alpha\n      a,\n      lda,\n      strideA,     // distance in memory between the first element of neighboring batch\n      b,\n      ldb,\n      strideB,     // distance in memory between the first element of neighboring batch\n      beta,               // scalar beta\n      c,                  // source matrix C\n      ldc,\n      strideC,     // distance in memory between the first element of neighboring batch\n      c,                  // destination matrix C (may be different memory than source C matrix)\n      ldc,\n      strideC,    // distance in memory between the first element of neighboring batch\n      iterBatchCount\n    );\n\n    AT_ASSERTM(result == 0, \"Failed to initialize CUTLASS Gemm::Params object.\");\n    // Launch the CUTLASS GEMM kernel.\n    THCudaCheck(Gemm::launch(params, stream));\n\n    // Update batched GEMM params based on completed work\n    batchesLeft = batchesLeft - iterBatchCount;\n    a += iterBatchCount * strideA;\n    b += iterBatchCount * strideB;\n    c += iterBatchCount * strideC;\n\n    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));\n    \n  } while(batchesLeft > 0);\n}\n\nvoid gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,\n                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                           float beta, half *c, long ldc, long strideC, long batchCount) {\n  auto stream = c10::cuda::getCurrentCUDAStream();\n  //printf(\"GEMM   -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\\n\", (transa == 't' ? 
'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);\n  if        ( (transa == 't') && (transb == 'n') ) { \n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); \n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 
0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, 
batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, 
strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, 
b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n  } else if ( (transa == 'n') && (transb == 'n') ) {\n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 
0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, 
beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { 
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, 
batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n  } else if ( (transa == 'n') && (transb == 't') ) {\n    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }\n    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { \n      int m_rem = m % 64;\n      int n_rem = n % 64;\n      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); \n      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {\n        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); \n      } else {\n        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); \n      }\n    }*/\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { 
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if 
(!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, 
beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }\n    else                                                   { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, 
ldc, strideC, batchCount); }\n  } else {\n    AT_ASSERTM(false, \"TransA and TransB are invalid\");\n  }\n}\n\nvoid adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)\n{\n  int transa_ = ((transa == 't') || (transa == 'T'));\n  int transb_ = ((transb == 't') || (transb == 'T'));\n\n  // Note: leading dimensions generally are checked that they are > 0 and at least as big as the result\n  // requires (even if the value won't be used).\n  if(n <= 1)\n    *ldc = std::max<int64_t>(m, 1);\n\n  if(transa_)\n  {\n    if(m <= 1)\n      *lda = std::max<int64_t>(k, 1);\n  }\n  else\n  {\n    if(k <= 1)\n      *lda = std::max<int64_t>(m, 1);\n  }\n\n  if(transb_)\n  {\n    if(k <= 1)\n      *ldb = std::max<int64_t>(n, 1);\n  }\n  else\n  {\n    if(n <= 1)\n      *ldb = std::max<int64_t>(k, 1);\n  }\n\n}\n\nvoid HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,\n                             float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,\n                             float beta, half *c, long ldc, long strideC, long batchCount)\n{\n  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )\n  {\n    THError(\"HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount \"\n            \"with the bound [val] < %d\", INT_MAX);\n  }\n\n  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);\n\n  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n  gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n}\n\n/******\nat::Tensor strided_batched_gemm_cuda(\n    float beta,\n    at::Tensor in_result,\n    float alpha,\n    at::Tensor batch1,\n    at::Tensor batch2) {\n\n  bool 
transpose_result;\n  char transpose_batch1, transpose_batch2;\n  int64_t lda, ldb, ldc;\n  at::Tensor result, input1, input2;\n  if (in_result.stride(1) == 1)\n  {\n    transpose_result = false;\n    result = in_result;\n    ldc = result.stride(2);\n  }\n  else if (in_result.stride(2) == 1)\n  {\n    transpose_result = true;\n\n    at::Tensor swap = batch2;\n    batch2 = batch1;\n    batch1 = swap;\n\n    result = in_result;\n    ldc = result.stride(1);\n  } else { \n    AT_ASSERTM(false, \"result should be contiguous\");\n  }\n\n  if (batch1.stride(transpose_result ? 2 : 1) == 1 &&\n      batch1.stride(transpose_result ? 1 : 2) != 0) {\n    transpose_batch1 = 'n';\n    input1 = batch1;\n    lda = input1.stride(transpose_result ? 1 : 2);\n  } else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&\n             batch1.stride(transpose_result ? 2 : 1) != 0) {\n    transpose_batch1 = 't';\n    input1 = batch1;\n    lda = input1.stride(transpose_result ? 2 : 1);\n  } else {\n    AT_ASSERTM(false, \"input1 should be contiguous\");\n  }\n\n  if (batch2.stride(transpose_result ? 2 : 1) == 1 &&\n      batch2.stride(transpose_result ? 1 : 2) != 0) {\n    transpose_batch2 = 'n';\n    input2 = batch2;\n    ldb = input2.stride(transpose_result ? 1 : 2);\n  } else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&\n             batch2.stride(transpose_result ? 2 : 1) != 0) {\n    transpose_batch2 = 't';\n    input2 = batch2;\n    ldb = input2.stride(transpose_result ? 2 : 1);\n  } else {\n    AT_ASSERTM(false, \"input2 should be contiguous\");\n  }\n  int64_t num_batches = result.size(0);\n\n  HgemmStridedBatched(\n      state,\n      transpose_batch1,\n      transpose_batch2,\n      result.size(transpose_result ? 2 : 1),\n      result.size(transpose_result ? 1 : 2),\n      input1.size(transpose_result ? 
1 : 2),\n      alpha,\n      static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),\n      static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),\n      beta,\n      static_cast<half*>(result.data_ptr()), ldc, result.stride(0),\n      num_batches);\n\n  return in_result;\n}\n\n***/\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
    "content": "#include <torch/extension.h>\n\n// CUDA forward declaration\nvoid fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);\n\nvoid fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\nvoid fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\nvoid fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\n\nvoid fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);\n\nvoid maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);\nvoid maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n// C++ interface\nvoid strided_check_finite(\n\t\tat::Tensor& overflow_flag,\n\t\tat::Tensor& p_copy,\n\t\tint stride,\n\t\tint clear_overflow_first\n\t ) {\n\tCHECK_INPUT(p_copy);\n\tfused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);\n}\nvoid adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, 
float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n        fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, 
at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n        CHECK_INPUT(p);\n        CHECK_INPUT(m);\n        CHECK_INPUT(v);\n        CHECK_INPUT(g);\n        int64_t num_elem = p.numel();\n        AT_ASSERTM(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n        AT_ASSERTM(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n        AT_ASSERTM(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n\n        fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {\n\tCHECK_INPUT(p_in);\n\tCHECK_INPUT(p_out);\n\tint64_t num_elem = p_in.numel();\n\tAT_ASSERTM(p_out.numel() == num_elem, \"number of elements in p_in and p_out should be equal\");\n\n\tmaybe_cast_cuda(overflow_flag, p_in, p_out);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n        m.def(\"strided_check_finite\", &strided_check_finite, \"Strided finite check.\");\n        m.def(\"adam\", &adam, \"Adam optimized CUDA implementation.\");\n        m.def(\"reversible_adam\", &reversible_adam, \"Reversible Adam optimized CUDA implementation.\");\n        m.def(\"adam_mt\", &fused_adam_cuda_mt, \"Multi tensor Adam optimized CUDA implementation.\");\n        m.def(\"maybe_adam_undo\", &maybe_adam_undo, \"Undo function for Adam optimized CUDA implementation.\");\n        m.def(\"maybe_cast\", &maybe_cast, \"Unpack byte tensor containing e5m2 floats.\");\n        m.def(\"maybe_cast_mt\", &maybe_cast_cuda_mt, \"Unpack byte tensor containing e5m2 floats.\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu",
    "content": "#include \"ATen/ATen.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ATen/cuda/detail/IndexUtils.cuh\"\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n#include <cmath>\n#include \"ATen/TensorUtils.h\"\n// #include \"ATen/Type.h\"\n#include \"ATen/AccumulateType.h\"\n#include <THC/THCGeneral.h>\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\n#include \"type_shim.h\"\n\ntypedef enum{\n    ADAM_MODE_0   =0, // eps under square root\n    ADAM_MODE_1   =1  // eps outside square root\n} adamMode_t;\n\ntemplate <typename T, typename GRAD_T>\n__global__ void adam_cuda_kernel(\n        T* __restrict__ p,\n        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n        //Assuming 2D grids and 2D blocks\n        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n        const int threadsPerBlock = blockDim.x * blockDim.y;\n        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n        const int i = (blockId * threadsPerBlock + threadIdInBlock);\n        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n        for (int j = i; j < tsize; j+=totThreads) {\n                T scaled_grad = g[j]/grad_scale;\n                m[j] = b1*m[j] 
+ (1-b1)*scaled_grad;\n                v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(v[j] + eps);\n                else // Mode 1\n                    denom = sqrtf(v[j]) + eps;\n                float update = (m[j]/denom) + (decay*p[j]);\n                p[j] = p[j] - (step_size*update);\n                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];\n        }\n}\n\ntemplate <int DEPTH, typename T, typename GRAD_T>\nstruct AdamFunctor\n{\n    __device__ __forceinline__ void operator()(\n        int chunk_size,\n        volatile int* noop_gmem,\n        TensorListMetadata<DEPTH>& tl,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        adamMode_t mode,\n        const float decay)\n    {\n        int tensor_loc = tl.block_to_tensor[blockIdx.x];\n        int chunk_idx = tl.block_to_chunk[blockIdx.x];\n        int n = tl.sizes[tensor_loc];\n\n        T* p = (T *)tl.addresses[0][tensor_loc];\n        p += chunk_idx*chunk_size;\n        T* m = (T *)tl.addresses[1][tensor_loc];\n        m += chunk_idx*chunk_size;\n        T* v = (T *)tl.addresses[2][tensor_loc];\n        v += chunk_idx*chunk_size;\n        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];\n        g += chunk_idx*chunk_size;\n        GRAD_T* p_copy = NULL;\n        if (DEPTH == 5) {\n            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];\n            p_copy += chunk_idx*chunk_size;\n        }\n\n        n -= chunk_idx*chunk_size;\n\n        T incoming_p[ILP];\n        T incoming_m[ILP];\n        T incoming_v[ILP];\n        T incoming_g[ILP];\n\n        // to make things simple, we put aligned case in a different code path\n        if(n % ILP == 0 &&\n           chunk_size % ILP == 0 &&\n           is_aligned(p) &&\n           is_aligned(m) &&\n           is_aligned(v) &&\n           
is_aligned(g) &&\n           is_aligned(p_copy))\n        {\n          for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n          {\n            // load\n            GRAD_T tmp_g[ILP];\n            load_store(incoming_p, p, 0, i_start);\n            load_store(incoming_m, m, 0, i_start);\n            load_store(incoming_v, v, 0, i_start);\n            load_store(tmp_g, g, 0, i_start);\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              incoming_g[ii] = static_cast<T>(tmp_g[ii]);\n              T scaled_grad = incoming_g[ii]/grad_scale;\n              incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n              incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n              float denom;\n              if (mode == ADAM_MODE_0)\n                denom = sqrtf(incoming_v[ii] + eps);\n              else // Mode 1\n                denom = sqrtf(incoming_v[ii]) + eps;\n              float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);\n              incoming_p[ii] = incoming_p[ii] - (step_size*update);\n              if (DEPTH == 5)  tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);\n            }\n            load_store(p, incoming_p, i_start, 0);\n            load_store(m, incoming_m, i_start, 0);\n            load_store(v, incoming_v, i_start, 0);\n            if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);\n          }\n        }\n        else\n        {\n          for(int i_start = 0;\n              i_start < n && i_start < chunk_size;\n              i_start += blockDim.x*ILP) {\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              incoming_p[ii] = 0;\n              incoming_m[ii] = 0;\n              incoming_v[ii] = 0;\n              incoming_g[ii] = 0;\n\n              int i = i_start + threadIdx.x + ii*blockDim.x;\n              if (i < n && i < chunk_size) {\n                incoming_p[ii] = p[i];\n                
incoming_m[ii] = m[i];\n                incoming_v[ii] = v[i];\n                incoming_g[ii] = static_cast<T>(g[i]);\n              }\n            }\n\n            // note for clarification to future michael:\n            // From a pure memory dependency perspective, there's likely no point unrolling\n            // the write loop, since writes just fire off once their LDGs arrive.\n            // Put another way, the STGs are dependent on the LDGs, but not on each other.\n            // There is still compute ILP benefit from unrolling the loop though.\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n              int j = i_start + threadIdx.x + ii*blockDim.x;\n\n              if(j < n && j < chunk_size) {\n                T scaled_grad = incoming_g[ii]/grad_scale;\n                m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n                v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                  denom = sqrtf(v[j] + eps);\n                else // Mode 1\n                  denom = sqrtf(v[j]) + eps;\n                float update = (m[j]/denom) + (decay*incoming_p[ii]);\n                p[j] = incoming_p[ii] - (step_size*update);\n                if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];\n              }\n            }\n          }\n        }\n    }\n};\n\nvoid fused_adam_cuda(\n        at::Tensor & p,\n        at::Tensor & p_copy,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n//        using namespace at;\n\n        //Get tensor size\n        int tsize = p.numel();\n        //Determine #threads and #blocks\n        const int threadsPerBlock = 512;\n        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n    
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n        //Constants\n        float step_size = 0;\n        if (bias_correction == 1) {\n            const float bias_correction1 = 1 - std::pow(beta1, step);\n            const float bias_correction2 = 1 - std::pow(beta2, step);\n            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n        }\n        else {\n            step_size = lr;\n        }\n        cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n        if (g.scalar_type() == at::ScalarType::Half) {\n//all other values should be fp32 for half gradients\n            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n//dispatch is done on the gradient type\n            using namespace at; // prevents \"toString is undefined\" errors\n            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                        p.DATA_PTR<accscalar_t>(),\n                        p_copy.numel() ? 
p_copy.DATA_PTR<scalar_t_0>() : NULL,\n                        m.DATA_PTR<accscalar_t>(),\n                        v.DATA_PTR<accscalar_t>(),\n                        g.DATA_PTR<scalar_t_0>(),\n                        beta1,\n                        beta2,\n                        eps,\n                        grad_scale,\n                        step_size,\n                        tsize,\n                        (adamMode_t) mode,\n                        decay);\n                );\n      } else {\n            using namespace at;\n            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                        p.DATA_PTR<scalar_t_0>(),\n                        NULL, //don't output p_copy for fp32, it's wasted write\n                        m.DATA_PTR<scalar_t_0>(),\n                        v.DATA_PTR<scalar_t_0>(),\n                        g.DATA_PTR<scalar_t_0>(),\n                        beta1,\n                        beta2,\n                        eps,\n                        grad_scale,\n                        step_size,\n                        tsize,\n                        (adamMode_t) mode,\n                        decay);\n            );\n      }\n      THCudaCheck(cudaGetLastError());\n\n}\n\nvoid fused_adam_cuda_mt(\n    int chunk_size,\n    at::Tensor noop_flag,\n    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy\n    float lr,\n    float beta1,\n    float beta2,\n    float eps,\n    float grad_scale,\n    int step,\n    int mode,\n    int bias_correction,\n    float decay) {\n\n    //Constants\n    float step_size = 0;\n    if (bias_correction == 1) {\n        const float bias_correction1 = 1 - std::pow(beta1, step);\n        const float bias_correction2 = 1 - std::pow(beta2, step);\n        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n    }\n    else {\n        step_size = lr;\n   
 }\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    size_t tl_sz = tensor_lists.size();\n    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, \"expected tensor lists of size 4 or 5\");\n\n    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {\n//all other values should be fp32 for half gradients\n        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n//dispatch is done on the gradient type\n        if (tl_sz == 5) {\n            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                multi_tensor_apply<5>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<5, accscalar_t, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        } else {\n            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                multi_tensor_apply<4>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<4, accscalar_t, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        }\n    } else {\n        if (tl_sz == 5) {\n            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                multi_tensor_apply<5>(\n             
       BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        } else {\n            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                multi_tensor_apply<4>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    noop_flag,\n                    tensor_lists,\n                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    (adamMode_t) mode,\n                    decay);\n            );\n        }\n    }\n    THCudaCheck(cudaGetLastError());\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__device__ void convert(const FROM_T vi, TO_T& vo)\n{\n    vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = vi;\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = 
static_cast<float>(vi);\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = t.as_half;\n}\n\ntemplate <typename GRAD_T>\n__global__ void strided_check_finite_cuda_kernel(\n        volatile int* noop_gmem,\n        GRAD_T* __restrict__ p_copy,\n        const size_t tsize,\n        int stride,\n        int clear_overflow_first)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;\n\n    if (clear_overflow_first) {\n        if (i == 0) {\n            *noop_gmem = 0;\n        }\n        __syncthreads();\n    }\n\n    for (int j = i; j < tsize; j+=totThreads) {\n        GRAD_T pi = p_copy[j];\n        if (!isfinite(pi)) {\n            *noop_gmem = 1;\n        }\n    }\n}\ntemplate <>\n__global__ void strided_check_finite_cuda_kernel(\n        volatile int* noop_gmem,\n        uint8_t* __restrict__ p_copy,\n        const size_t tsize,\n        int stride,\n        int clear_overflow_first)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;\n\n   
 if (clear_overflow_first) {\n        if (i == 0) {\n            *noop_gmem = 0;\n        }\n        __syncthreads();\n    }\n\n    for (int j = i; j < tsize; j+=totThreads) {\n        at::Half pi;\n        convert(p_copy[j], pi);\n        if (!isfinite(pi)) {\n            *noop_gmem = 1;\n        }\n    }\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__global__ void maybe_cast_kernel(\n        volatile int* overflow_flag,\n        const FROM_T* p_in,\n        TO_T* p_out,\n        const size_t tsize)\n{\n    if (overflow_flag && *overflow_flag != 0) return;\n\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    FROM_T pi[ILP];\n    TO_T po[ILP];\n\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            pi[ii] = 0;\n\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p_in[j];\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            convert(pi[ii], po[ii]);\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                p_out[j] = po[ii];\n            }\n        }\n    }\n}\n\ntemplate <typename T, typename GRAD_T, typename REDU_T>\n__global__ void reversible_adam_cuda_kernel(\n        T* __restrict__ p,\n        REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float 
eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    T mi[ILP];\n    T vi[ILP];\n    T pi[ILP];\n    T gi[ILP];\n\n    bool overflow = false;\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            mi[ii] = T(0);\n            vi[ii] = T(0);\n            pi[ii] = T(0);\n            gi[ii] = GRAD_T(0);\n\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p[j];\n                mi[ii] = m[j];\n                vi[ii] = v[j];\n                gi[ii] = static_cast<T>(g[j]);\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            T scaled_grad = gi[ii]/grad_scale;\n            if (isfinite(scaled_grad)) {\n                mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;\n                vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;\n                float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(vi[ii] + eps);\n                else // Mode 1\n                    denom = sqrtf(vi[ii]) + eps;\n                float update = (mi[ii]/denom) + (decay*pi[ii]);\n                pi[ii] = pi[ii] - (step_size*update);\n            } else {\n                overflow = true;\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                m[j] = mi[ii];\n                v[j] = 
vi[ii];\n                p[j] = pi[ii];\n                if (p_copy != NULL) {\n                    convert(pi[ii], p_copy[j]);\n                }\n            }\n        }\n    }\n\n    if (p_copy != NULL) {\n        __syncthreads();\n        if (overflow) {\n            convert(float(INFINITY), p_copy[0]);\n        }\n    }\n}\n\ntemplate <typename T, typename GRAD_T>\n__global__ void maybe_adam_undo_cuda_kernel(\n        volatile int* overflow_flag,\n        T* __restrict__ p,\n        T* __restrict__ m,\n        T* __restrict__ v,\n        const GRAD_T * __restrict__ g,\n        const float b1,\n        const float b2,\n        const float eps,\n        const float grad_scale,\n        const float step_size,\n        const size_t tsize,\n        adamMode_t mode,\n        const float decay)\n{\n    // NB! Skip undo kernel when overflow flag is NOT set\n    if (overflow_flag && *overflow_flag == 0) return;\n\n    //Assuming 2D grids and 2D blocks\n    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n    const int threadsPerBlock = blockDim.x * blockDim.y;\n    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n    const int i = (blockId * threadsPerBlock + threadIdInBlock);\n    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;\n\n    T mi[ILP];\n    T vi[ILP];\n    T pi[ILP];\n    T gi[ILP];\n\n    for(int j_start = 0;  j_start < tsize;  j_start+=totThreads*ILP) {\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            mi[ii] = T(0);\n            vi[ii] = T(0);\n            pi[ii] = T(0);\n            gi[ii] = GRAD_T(0);\n\n            // Give each ILP lane a distinct element, matching the other kernels in this file.\n            // (j_start + i*ILP would make all ILP lanes alias one element and skip the rest.)\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                pi[ii] = p[j];\n                mi[ii] = m[j];\n                vi[ii] = v[j];\n                gi[ii] = static_cast<T>(g[j]);\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            T scaled_grad = gi[ii]/grad_scale;\n            if (isfinite(scaled_grad)) {\n     
           float denom;\n                if (mode == ADAM_MODE_0)\n                    denom = sqrtf(vi[ii] + eps);\n                else // Mode 1\n                    denom = sqrtf(vi[ii]) + eps;\n                pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);\n                mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;\n                vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;\n                // Make sure round off errors don't create (small) negative value.\n                // This can happen if we have to revert the very first step.\n                vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;\n            }\n        }\n\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++) {\n            int j = j_start + i + totThreads*ii;\n            if (j < tsize) {\n                m[j] = mi[ii];\n                v[j] = vi[ii];\n                p[j] = pi[ii];\n            }\n        }\n    }\n}\n\ntemplate <int DEPTH, typename FROM_T, typename TO_T>\nstruct MaybeCastFunctor\n{\n    __device__ __forceinline__ void operator()(\n        int chunk_size,\n        volatile int* overflow_flag,\n        TensorListMetadata<DEPTH>& tl)\n    {\n        if (overflow_flag && *overflow_flag != 0) return;\n\n        int tensor_loc = tl.block_to_tensor[blockIdx.x];\n        int chunk_idx = tl.block_to_chunk[blockIdx.x];\n        int n = tl.sizes[tensor_loc];\n\n        FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];\n        p_in += chunk_idx*chunk_size;\n        TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];\n        p_out += chunk_idx*chunk_size;\n\n        n -= chunk_idx*chunk_size;\n        int dim = chunk_size < n ? 
chunk_size : n;\n\n\tFROM_T pi[ILP];\n        TO_T po[ILP];\n\n        for(int j_start = 0;  j_start < dim;  j_start+=blockDim.x*ILP) {\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                pi[ii] = FROM_T(0);\n                int j = j_start + threadIdx.x + ii*blockDim.x;\n                if (j < dim) {\n                    pi[ii] = p_in[j];\n                }\n            }\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                convert(pi[ii], po[ii]);\n            }\n\n#pragma unroll\n            for(int ii = 0; ii < ILP; ii++) {\n                int j = j_start + threadIdx.x + ii*blockDim.x;\n                if (j < dim) {\n                    p_out[j] = po[ii];\n                }\n            }\n        }\n    }\n};\n\nvoid fused_strided_check_finite(\n\tat::Tensor & overflow_flag,\n        at::Tensor & p_copy,\n        int stride,\n\tint clear_overflow_first)\n{\n\t//Get tensor size\n\tint tsize = p_copy.numel();\n\tint niter = (tsize + stride - 1) / stride;\n\n\t//Determine #threads and #blocks\n\tconst int threadsPerBlock = 512;\n\t//In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set.\n\tconst dim3 blocks(clear_overflow_first ? 
1 : (niter+threadsPerBlock-1)/threadsPerBlock);\n\tAT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), \"parameter tensor is too large to be indexed with int32\");\n\n\tcudaStream_t stream = at::cuda::getCurrentCUDAStream();\n        using namespace at; // prevents \"toString is undefined\" errors\n        DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, \"check_finite_cuda_kernel\",\n                strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.DATA_PTR<int>(),\n                    p_copy.DATA_PTR<scalar_t_0>(),\n                    tsize,\n                    stride,\n                    clear_overflow_first);\n                );\n\tTHCudaCheck(cudaGetLastError());\n}\n\nvoid fused_reversible_adam_cuda(\n        at::Tensor & p,\n        at::Tensor & p_copy,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n//      using namespace at;\n\n      //Get tensor size\n      int tsize = p.numel();\n      //Determine #threads and #blocks\n      const int threadsPerBlock = 512;\n      const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n      AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n      //Constants\n      float step_size = 0;\n      if (bias_correction == 1) {\n          const float bias_correction1 = 1 - std::pow(beta1, step);\n          const float bias_correction2 = 1 - std::pow(beta2, step);\n          step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n      }\n      else {\n          step_size = lr;\n      }\n      cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n      if (g.scalar_type() == at::ScalarType::Half) {\n          //all other values should 
be fp32 for half gradients\n          AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n          //dispatch is done on the gradient type\n          using namespace at; // prevents \"toString is undefined\" errors\n          if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {\n              DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                      using accscalar_t = at::acc_type<scalar_t_0, true>;\n                      reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                          p.DATA_PTR<accscalar_t>(),\n                          p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,\n                          m.DATA_PTR<accscalar_t>(),\n                          v.DATA_PTR<accscalar_t>(),\n                          g.DATA_PTR<scalar_t_0>(),\n                          beta1,\n                          beta2,\n                          eps,\n                          grad_scale,\n                          step_size,\n                          tsize,\n                          (adamMode_t) mode,\n                          decay);\n                      );\n          } else {\n              AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, \"expected parameter to be of byte type\");\n              DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_e5m2_kernel\",\n                      using accscalar_t = at::acc_type<scalar_t_0, true>;\n                      reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(\n                          p.DATA_PTR<accscalar_t>(),\n                          p_copy.DATA_PTR<uint8_t>(),\n                          m.DATA_PTR<accscalar_t>(),\n                          v.DATA_PTR<accscalar_t>(),\n                          g.DATA_PTR<scalar_t_0>(),\n                          beta1,\n                          
beta2,\n                          eps,\n                          grad_scale,\n                          step_size,\n                          tsize,\n                          (adamMode_t) mode,\n                          decay);\n                      );\n          }\n      } else {\n          using namespace at;\n          DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                  reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                      p.DATA_PTR<scalar_t_0>(),\n                      NULL, //don't output p_copy for fp32, it's wasted write\n                      m.DATA_PTR<scalar_t_0>(),\n                      v.DATA_PTR<scalar_t_0>(),\n                      g.DATA_PTR<scalar_t_0>(),\n                      beta1,\n                      beta2,\n                      eps,\n                      grad_scale,\n                      step_size,\n                      tsize,\n                      (adamMode_t) mode,\n                      decay);\n                  );\n      }\n      THCudaCheck(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda(\n        at::Tensor & overflow_flag,\n        at::Tensor & p_in,\n        at::Tensor & p_out)\n{\n      //Get tensor size\n      int tsize = p_in.numel();\n      AT_ASSERTM(tsize == p_out.numel(), \"p_in.numel() must equal p_out.numel()\");\n      //Determine #threads and #blocks\n      const int threadsPerBlock = 512;\n      const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n      AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), \"parameter tensor is too large to be indexed with int32\");\n      //Constants\n      cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n      DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, \"maybe_cast_cuda\",\n              DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, \"maybe_cast_cuda\",\n                  
maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(\n                      overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,\n                      p_in.DATA_PTR<scalar_t_0>(),\n                      p_out.DATA_PTR<scalar_t_1>(),\n                      tsize); ))\n      THCudaCheck(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda_mt(\n    int chunk_size,\n    at::Tensor overflow_flag,\n    std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out\n{\n    //Constants\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    size_t tl_sz = tensor_lists.size();\n    AT_ASSERTM(tl_sz == 2, \"expected tensor lists of size 2\");\n\n    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, \"maybe_cast_cuda_mt_kernel\",\n            DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, \"maybe_cast_cuda_mt_kernel\",\n                multi_tensor_apply<2>(\n                    BLOCK_SIZE,\n                    chunk_size,\n                    overflow_flag,\n                    tensor_lists,\n                    MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))\n    THCudaCheck(cudaGetLastError());\n}\n\nvoid fused_maybe_adam_undo_cuda(\n        at::Tensor & overflow_flag,\n        at::Tensor & p,\n        at::Tensor & m,\n        at::Tensor & v,\n        at::Tensor & g,\n        float lr,\n        float beta1,\n        float beta2,\n        float eps,\n        float grad_scale,\n        int step,\n        int mode,\n        int bias_correction,\n        float decay)\n{\n    //Get tensor size\n    int tsize = p.numel();\n    //Determine #threads and #blocks\n    const int threadsPerBlock = 512;\n    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);\n    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n    //Constants\n    float step_size = 0;\n    if (bias_correction == 1) {\n        const float bias_correction1 = 
1 - std::pow(beta1, step);\n        const float bias_correction2 = 1 - std::pow(beta2, step);\n        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;\n    }\n    else {\n        step_size = lr;\n    }\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    if (g.scalar_type() == at::ScalarType::Half) {\n        //all other values should be fp32 for half gradients\n        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n        //dispatch is done on the gradient type\n        using namespace at; // prevents \"toString is undefined\" errors\n        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                using accscalar_t = at::acc_type<scalar_t_0, true>;\n                maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,\n                    p.DATA_PTR<accscalar_t>(),\n                    m.DATA_PTR<accscalar_t>(),\n                    v.DATA_PTR<accscalar_t>(),\n                    g.DATA_PTR<scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    tsize,\n                    (adamMode_t) mode,\n                    decay);\n                );\n    } else {\n        using namespace at;\n        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(\n                    overflow_flag.numel() ? 
overflow_flag.DATA_PTR<int>() : NULL,\n                    p.DATA_PTR<scalar_t_0>(),\n                    m.DATA_PTR<scalar_t_0>(),\n                    v.DATA_PTR<scalar_t_0>(),\n                    g.DATA_PTR<scalar_t_0>(),\n                    beta1,\n                    beta2,\n                    eps,\n                    grad_scale,\n                    step_size,\n                    tsize,\n                    (adamMode_t) mode,\n                    decay);\n                );\n    }\n    THCudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  const float lr,\n  const float beta1,\n  const float beta2,\n  const float epsilon,\n  const int step,\n  const int bias_correction,\n  const float weight_decay,\n  const int grad_averaging,\n  const int mode,\n  const float global_grad_norm,\n  const float max_grad_norm);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n        m.def(\"lamb\", &multi_tensor_lamb_cuda, \"Computes and apply update for LAMB optimizer\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntypedef enum{\n  MOMENT_MODE_0   =0, // L2 regularization mode\n  MOMENT_MODE_1   =1  // Decoupled weight decay mode\n} adamMode_t;\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::optional<bool> per_tensor_python);\n\nusing MATH_T = float;\n\ntemplate<typename T>\nstruct LAMBStage1Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<4>& tl,\n    const float beta1,\n    const float beta2,\n    const float beta3,\n    const float beta1_correction,\n    const float beta2_correction,\n    const float epsilon,\n    adamMode_t mode,\n    const float decay,\n    const float global_grad_norm,\n    const float max_global_grad_norm)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? 
global_grad_norm / max_global_grad_norm : 1.0f;\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx*chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for(int i_start = 0;\n            i_start < n && i_start < chunk_size;\n            i_start += blockDim.x*ILP)\n    {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T r_v[ILP];\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          r_g[ii] = g[i];\n          // special ?optimization? for lamb stage 1\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          }\n          else {\n            r_p[ii] = p[i];\n          }\n          r_m[ii] = m[i];\n          r_v[ii] = v[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        if (mode == MOMENT_MODE_0) {\n\t  MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n\t  // L2 on scaled grad\n          scaled_grad = scaled_grad + decay*r_p[ii];\n          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = next_m_unbiased / denom;\n        }\n        else {\n          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n          
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          g[i] = r_p[ii];\n          m[i] = r_m[ii];\n          v[i] = r_v[ii];\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate<typename T>\nstruct LAMBStage2Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<2>& tl,\n    const float* per_tensor_param_norm,\n    const float* per_tensor_update_norm,\n    const float learning_rate,\n    const float decay)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = learning_rate;\n    // apply adaptive learning rate to parameters with non-zero weight decay\n    if (decay != 0.0) \n    {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? 
learning_rate * (param_norm / update_norm) : learning_rate;\n    }\n\n    T* update = (T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    for(int i_start = 0;\n            i_start < n && i_start < chunk_size;\n            i_start += blockDim.x*ILP)\n    {\n      MATH_T r_p[ILP];\n      MATH_T r_update[ILP];\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n       \tint i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          r_p[ii] = p[i];\n          r_update[ii] = update[i];\n        }\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n       \tr_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n      }\n#pragma unroll\n      for(int ii = 0; ii < ILP; ii++)\n      {\n        int i = i_start + threadIdx.x + ii*blockDim.x;\n        if(i < n && i < chunk_size)\n        {\n          p[i] = r_p[ii];\n        }\n      }\n    }\n  }\n};\n\n\nvoid multi_tensor_lamb_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  const float lr,\n  const float beta1,\n  const float beta2,\n  const float epsilon,\n  const int step,\n  const int bias_correction,\n  const float weight_decay,\n  const int grad_averaging,\n  const int mode,\n  const float global_grad_norm,\n  const float max_grad_norm)\n{\n  using namespace at;\n  // Master weight and 32bit momentum(potentially changing) is not handled by this\n  // So we assume every tensor are all in the same type\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = 1 - std::pow(beta2, step);\n  }\n\n  // Handle grad averaging mode\n  float beta3 = 1.0f;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);\n  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);\n\n  // Compute per-tensor param norm\n  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);\n\n  // We modify grad in place to store the update before computing its norm.\n  // Generally this is not an issue, since people modify grad in the step() method all the time.\n  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/cpu code.\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n      multi_tensor_apply<4>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        LAMBStage1Functor<scalar_t_0>(),\n        beta1,\n        beta2,\n        beta3, // 1-beta1 or 1, depending on averaging mode\n        bias_correction1,\n        bias_correction2,\n        epsilon,\n        (adamMode_t) mode,\n        weight_decay,\n        global_grad_norm,\n        max_grad_norm); )\n\n  // Compute update norms\n  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);\n\n  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);\n\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n      multi_tensor_apply<2>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        grad_param_list,\n        LAMBStage2Functor<scalar_t_0>(),\n        std::get<1>(param_norm_tuple).DATA_PTR<float>(),\n        std::get<1>(update_norm_tuple).DATA_PTR<float>(),\n        lr,\n        weight_decay); )\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n}\n
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_fused_adam_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor per_tensor_eps,\n  at::Tensor per_tensor_weight_decay,\n  float lr,\n  float grad_scale,\n  int step,\n  int mode);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_fused_adam\", &multi_tensor_fused_adam_cuda,\n        \"Multi tensor Adam optimized CUDA implementation.\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <THC/THCGeneral.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n#include <cmath>\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntypedef enum{\n  ADAM_MODE_0   =0, // eps under square root\n  ADAM_MODE_1   =1  // eps outside square root\n} adamMode_t;\n\ntemplate <int DEPTH, typename T, typename GRAD_T>\nstruct DistAdamFunctor\n{\n  __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<DEPTH>& tl,\n    const float* per_tensor_beta1,\n    const float* per_tensor_beta2,\n    const int* per_tensor_bias_correction,\n    const float* per_tensor_eps,\n    const float* per_tensor_weight_decay,\n    const float lr,\n    const float grad_scale,\n    const int step,\n    adamMode_t mode)\n  {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float b1 = per_tensor_beta1[tensor_num];\n    float b2 = per_tensor_beta2[tensor_num];\n    float eps = per_tensor_eps[tensor_num];\n    float decay = per_tensor_weight_decay[tensor_num];\n\n    float beta1_correction = 1.0f, beta2_correction = 1.0f;\n    if (per_tensor_bias_correction[tensor_num] == 1) {\n      beta1_correction = 1 - std::pow(b1, step);\n      beta2_correction = 1 - std::pow(b2, 
step);\n    }\n\n    T* p = (T *)tl.addresses[0][tensor_loc];\n    p += chunk_idx*chunk_size;\n    T* m = (T *)tl.addresses[1][tensor_loc];\n    m += chunk_idx*chunk_size;\n    T* v = (T *)tl.addresses[2][tensor_loc];\n    v += chunk_idx*chunk_size;\n    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];\n    g += chunk_idx*chunk_size;\n    GRAD_T* p_copy = NULL;\n    if (DEPTH == 5) {\n      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];\n      p_copy += chunk_idx*chunk_size;\n    }\n\n    n -= chunk_idx*chunk_size;\n    \n    T incoming_p[ILP];\n    T incoming_m[ILP];\n    T incoming_v[ILP];\n    T incoming_g[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 &&\n      chunk_size % ILP == 0 &&\n      is_aligned(p) &&\n      is_aligned(m) &&\n      is_aligned(v) &&\n      is_aligned(g) &&\n      is_aligned(p_copy)) {\n      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        GRAD_T tmp_g[ILP];\n        load_store(incoming_p, p, 0, i_start);\n        load_store(incoming_m, m, 0, i_start);\n        load_store(incoming_v, v, 0, i_start);\n        load_store(tmp_g, g, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_g[ii] = static_cast<T>(tmp_g[ii]);\n          T scaled_grad = incoming_g[ii]/grad_scale;\n          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n          T next_m_unbiased = incoming_m[ii] / beta1_correction;\n\t  T next_v_unbiased = incoming_v[ii] / beta2_correction;\n\t  float denom;\n          if (mode == ADAM_MODE_0)\n            denom = sqrtf(next_v_unbiased + eps);\n          else // Mode 1\n            denom = sqrtf(next_v_unbiased) + eps;\n          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);\n          incoming_p[ii] = incoming_p[ii] - (lr * update);\n\t  if 
(DEPTH == 5)  tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);\n        }\n        load_store(p, incoming_p, i_start, 0);\n        load_store(m, incoming_m, i_start, 0);\n        load_store(v, incoming_v, i_start, 0);\n        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP) {\n\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_p[ii] = 0;\n          incoming_m[ii] = 0;\n          incoming_v[ii] = 0;\n          incoming_g[ii] = 0;\n\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if (i < n && i < chunk_size) {\n            incoming_p[ii] = p[i];\n            incoming_m[ii] = m[i];\n            incoming_v[ii] = v[i];\n            incoming_g[ii] = static_cast<T>(g[i]);\n          }\n        }\n\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int j = i_start + threadIdx.x + ii*blockDim.x;\n\n          if (j < n && j < chunk_size) {\n            T scaled_grad = incoming_g[ii]/grad_scale;\n            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;\n            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;\n            T next_m_unbiased = m[j] / beta1_correction;\n            T next_v_unbiased = v[j] / beta2_correction;\n\t    float denom;\n            if (mode == ADAM_MODE_0)\n              denom = sqrtf(next_v_unbiased + eps);\n            else // Mode 1\n              denom = sqrtf(next_v_unbiased) + eps;\n            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);\n            p[j] = incoming_p[ii] - (lr * update);\n\t    if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_fused_adam_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,  // p, m, v, g, p_copy\n  at::Tensor per_tensor_beta1,\n  at::Tensor 
per_tensor_beta2,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor per_tensor_eps,\n  at::Tensor per_tensor_weight_decay,\n  float lr,\n  float grad_scale,\n  int step,\n  int mode)\n{\n  using namespace at;\n\n  size_t tl_sz = tensor_lists.size();\n  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, \"expected tensor lists of size 4 or 5\");\n\n  if (tl_sz == 5) {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"dist_adam_cuda_kernel\",  // g\n      using accscalar_t = at::acc_type<scalar_t_0, true>;\n      multi_tensor_apply<5>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        DistAdamFunctor<5, accscalar_t, scalar_t_0>(),\n        per_tensor_beta1.DATA_PTR<float>(),\n        per_tensor_beta2.DATA_PTR<float>(),\n        per_tensor_bias_correction.DATA_PTR<int>(),\n        per_tensor_eps.DATA_PTR<float>(),\n        per_tensor_weight_decay.DATA_PTR<float>(),\n        lr,\n        grad_scale,\n        step,\n        (adamMode_t) mode);\n    );\n  } else {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"dist_adam_cuda_kernel\",  // g\n      using accscalar_t = at::acc_type<scalar_t_0, true>;\n      multi_tensor_apply<4>(\n        BLOCK_SIZE,\n        chunk_size,\n        noop_flag,\n        tensor_lists,\n        DistAdamFunctor<4, accscalar_t, scalar_t_0>(),\n        per_tensor_beta1.DATA_PTR<float>(),\n        per_tensor_beta2.DATA_PTR<float>(),\n        per_tensor_bias_correction.DATA_PTR<int>(),\n        per_tensor_eps.DATA_PTR<float>(),\n        per_tensor_weight_decay.DATA_PTR<float>(),\n        lr,\n        grad_scale,\n        step,\n        (adamMode_t) mode);\n    );\n  }\n  THCudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_compute_update_term_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_beta3,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor step,\n  at::Tensor per_tensor_epsilon,\n  const int mode,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_scale,\n  at::Tensor global_grad_norm,\n  const float max_grad_norm);\n\nvoid multi_tensor_lamb_update_weights_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_param_norm,\n  at::Tensor per_tensor_update_norm,\n  at::Tensor update_norm_offset,\n  at::Tensor learning_rate,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_grad_norm,\n  bool use_nvlamb);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_lamb_compute_update_term\", &multi_tensor_lamb_compute_update_term_cuda,\n        \"Computes update term for LAMB optimizer\");\n  m.def(\"multi_tensor_lamb_update_weights\", &multi_tensor_lamb_update_weights_cuda,\n        \"Applies update term for LAMB optimizer\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"type_shim.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate<typename T>\n__device__ __forceinline__ bool is_aligned(T* p){\n  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename FROM_T, typename TO_T> \n__device__ void convert(const FROM_T vi, TO_T& vo)\n{\n    vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = vi;\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo)\n{\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo)\n{\n    union S\n    {\n\tfloat as_float;\n\tint as_int;\n    };\n    S s;\n    s.as_float = static_cast<float>(vi);\n    s.as_int = s.as_int & 0xFF800000;\n    union T\n    {\n        at::Half as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n    vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo)\n{\n    union T\n    {\n        at::Half 
as_half;\n\tuint8_t as_byte[2];\n    };\n    T t;\n    t.as_byte[0] = 0;\n    t.as_byte[1] = vi;\n    vo = t.as_half;\n}\n\ntypedef enum{\n  MOMENT_MODE_0   =0, // L2 regularization mode\n  MOMENT_MODE_1   =1  // Decoupled weight decay mode\n} adamMode_t;\n\ntemplate<typename T, typename GRAD_T, typename MATH_T>\nstruct DistOptLAMBStage1Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<5>& tl,\n    const MATH_T* per_tensor_beta1,\n    const MATH_T* per_tensor_beta2,\n    const MATH_T* per_tensor_beta3,\n    const int* per_tensor_bias_correction,\n    const int* step,\n    const MATH_T* per_tensor_epsilon,\n    adamMode_t mode,\n    const MATH_T* per_tensor_decay,\n    const MATH_T* global_scale,\n    const MATH_T* global_grad_norm,\n    const float max_grad_norm)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1)\n        return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float combined_scale = *global_scale;\n    if (max_grad_norm > 0) {\n        combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);\n\tcombined_scale = *global_scale / std::min((float) 1.0, combined_scale);\n    }\n\n    MATH_T beta1 = per_tensor_beta1[tensor_num];\n    MATH_T beta2 = per_tensor_beta2[tensor_num];\n    MATH_T beta3 = 1 - beta1;\n    MATH_T beta1_correction, beta2_correction;\n    if (per_tensor_bias_correction[tensor_num] == 1) {\n        beta1_correction = 1 - pow(beta1, *step);\n        beta2_correction = 1 - pow(beta2, *step);\n    } else {\n        beta1_correction = (MATH_T) 1.0;\n        beta2_correction = (MATH_T) 1.0;\n    }\n    MATH_T epsilon = per_tensor_epsilon[tensor_num];\n    MATH_T decay = per_tensor_decay[tensor_num];\n\n    GRAD_T* g = 
(GRAD_T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx*chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx*chunk_size;\n\n    MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];\n    u += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    MATH_T r_g[ILP];\n    MATH_T r_p[ILP];\n    MATH_T r_m[ILP];\n    MATH_T r_v[ILP];\n    // to make things simple, we put aligned case in a different code path\n    if(n % ILP == 0 &&\n       chunk_size % ILP == 0 &&\n       is_aligned(g) &&\n       is_aligned(p) &&\n       is_aligned(m) &&\n       is_aligned(v))\n    {\n      GRAD_T l_g[ILP];\n      T l_p[ILP];\n      T l_m[ILP];\n      T l_v[ILP];\n      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n      {\n        // load\n        load_store(l_g, g, 0, i_start);\n        if (decay != 0)\n          load_store(l_p, p, 0, i_start);\n        load_store(l_m, m, 0, i_start);\n        load_store(l_v, v, 0, i_start);\n        // unpack\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          r_g[ii] = l_g[ii];\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          }\n          else {\n            r_p[ii] = l_p[ii];\n          }\n          r_m[ii] = l_m[ii];\n          r_v[ii] = l_v[ii];\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay*r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = 
r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          }\n          else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          l_m[ii] = r_m[ii];\n          l_v[ii] = r_v[ii];\n        }\n        // store\n        load_store(u, r_p, i_start, 0);\n        load_store(m, l_m, i_start, 0);\n        load_store(v, l_v, i_start, 0);\n      }\n    }\n    else\n    {\n      // see note in multi_tensor_scale_kernel.cu\n      for(int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP)\n      {\n        MATH_T r_g[ILP];\n        MATH_T r_p[ILP];\n        MATH_T r_m[ILP];\n        MATH_T r_v[ILP];\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            r_g[ii] = g[i];\n            // special ?optimization? 
for lamb stage 1\n            if (decay == 0) {\n              r_p[ii] = MATH_T(0);\n            }\n            else {\n              r_p[ii] = p[i];\n            }\n            r_m[ii] = m[i];\n            r_v[ii] = v[i];\n          } else {\n            r_g[ii] = MATH_T(0);\n            r_p[ii] = MATH_T(0);\n            r_m[ii] = MATH_T(0);\n            r_v[ii] = MATH_T(0);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay*r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          }\n          else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            u[i] = r_p[ii];\n            m[i] = r_m[ii];\n            v[i] = r_v[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter 
value.\ntemplate<typename T, typename GRAD_T, typename MATH_T>\nstruct DistOptLAMBStage2Functor\n{\n   __device__ __forceinline__ void operator()(\n    int chunk_size,\n    volatile int* noop_gmem,\n    TensorListMetadata<3>& tl,\n    const MATH_T* per_tensor_param_norm,\n    const MATH_T* per_tensor_update_norm,\n    const long* update_norm_offset,\n    const MATH_T* learning_rate,\n    const MATH_T* per_tensor_decay,\n    const MATH_T* global_grad_norm,\n    bool use_nvlamb)\n  {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1)\n        return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T decay = per_tensor_decay[tensor_num];\n\n    MATH_T ratio = *learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != (MATH_T) 0.0))\n    {\n      MATH_T param_norm = per_tensor_param_norm[tensor_num];\n      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];\n      ratio = (update_norm != 0.0 && param_norm != 0.0) ? 
(*learning_rate) * (param_norm / update_norm) : (*learning_rate);\n    }\n\n    MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx*chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx*chunk_size;\n\n    GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];\n    p_copy += chunk_idx*chunk_size;\n\n    n -= chunk_idx*chunk_size;\n\n    // to make things simple, we put aligned case in a different code path\n    if(n % ILP == 0 &&\n       chunk_size % ILP == 0 &&\n       is_aligned(p) &&\n       is_aligned(update))\n    {\n      T r_p[ILP];\n      MATH_T r_update[ILP];\n      GRAD_T r_p_copy[ILP];\n      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)\n      {\n        // load\n        load_store(r_p, p, 0, i_start);\n        load_store(r_update, update, 0, i_start);\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n\t  r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);\n          convert(r_p[ii], r_p_copy[ii]);\n        }\n        load_store(p, r_p, i_start, 0);\n        load_store(p_copy, r_p_copy, i_start, 0);\n      }\n    }\n    else\n    {\n      for(int i_start = 0;\n          i_start < n && i_start < chunk_size;\n          i_start += blockDim.x*ILP)\n      {\n        MATH_T r_p[ILP];\n        MATH_T r_update[ILP];\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            r_p[ii] = p[i];\n            r_update[ii] = update[i];\n          }\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n        }\n#pragma unroll\n        for(int ii = 0; ii < ILP; ii++)\n        {\n          int i = i_start + threadIdx.x + ii*blockDim.x;\n          if(i < n && i < chunk_size)\n          {\n            p[i] = r_p[ii];\n           
 convert(r_p[ii], p_copy[i]);\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_compute_update_term_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_beta1,\n  at::Tensor per_tensor_beta2,\n  at::Tensor per_tensor_beta3,\n  at::Tensor per_tensor_bias_correction,\n  at::Tensor step,\n  at::Tensor per_tensor_epsilon,\n  const int mode,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_scale,\n  at::Tensor global_grad_norm,\n  const float max_grad_norm)\n{\n  using namespace at;\n\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_1\",\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, \"lamb_stage_1\",\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, \"lamb_stage_1\",\n        multi_tensor_apply<5>(\n          BLOCK_SIZE,\n          chunk_size,\n          noop_flag,\n          tensor_lists,\n          DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n          per_tensor_beta1.DATA_PTR<scalar_t_2>(),\n          per_tensor_beta2.DATA_PTR<scalar_t_2>(),\n          per_tensor_beta3.DATA_PTR<scalar_t_2>(),\n          per_tensor_bias_correction.DATA_PTR<int>(),\n          step.DATA_PTR<int>(),\n          per_tensor_epsilon.DATA_PTR<scalar_t_2>(),\n          (adamMode_t) mode,\n          per_tensor_decay.DATA_PTR<scalar_t_2>(),\n          global_scale.DATA_PTR<scalar_t_2>(),\n\t  global_grad_norm.DATA_PTR<scalar_t_2>(),\n\t  max_grad_norm); )))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_lamb_update_weights_cuda(\n  int chunk_size,\n  at::Tensor noop_flag,\n  std::vector<std::vector<at::Tensor>> tensor_lists,\n  at::Tensor per_tensor_param_norm,\n  at::Tensor per_tensor_update_norm,\n  at::Tensor update_norm_offset,\n  at::Tensor learning_rate,\n  at::Tensor per_tensor_decay,\n  at::Tensor global_grad_norm,\n  bool use_nvlamb)\n{\n  using namespace at;\n\n  
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_2\",\n    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, \"lamb_stage_2\",\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, \"lamb_stage_2\",\n        multi_tensor_apply<3>(\n          BLOCK_SIZE,\n          chunk_size,\n          noop_flag,\n          tensor_lists,\n          DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n          per_tensor_param_norm.DATA_PTR<scalar_t_2>(),\n          per_tensor_update_norm.DATA_PTR<scalar_t_2>(),\n          update_norm_offset.DATA_PTR<long>(),\n\t  learning_rate.DATA_PTR<scalar_t_2>(),\n          per_tensor_decay.DATA_PTR<scalar_t_2>(),\n\t  global_grad_norm.DATA_PTR<scalar_t_2>(),\n          use_nvlamb); )))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/transducer/transducer_joint.cpp",
    "content": "#include <torch/extension.h>\n#include <ATen/Functions.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize);\n\n\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale);\n\nstd::vector<torch::Tensor> transducer_joint_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize) {\n    CHECK_INPUT(f);\n    CHECK_INPUT(g);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(gLen);\n    if (packOutput)\n        CHECK_INPUT(batchOffset);\n    return transducer_joint_cuda_forward(\n        f, \n        g, \n        fLen, \n        gLen,\n        batchOffset,\n        packedBatch,\n        opt,\n        packOutput,\n        relu,\n        dropout,\n        dropoutProb,\n        tileSize);\n}\n\nstd::vector<torch::Tensor> transducer_joint_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale) {\n    for (auto t : in){\n        CHECK_INPUT(t);\n    }\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(gLen);\n    if (packOutput)\n     
   CHECK_INPUT(batchOffset);\n    return transducer_joint_cuda_backward(\n        in, \n        fLen, \n        gLen,\n        batchOffset,\n        maxFLen,\n        maxGLen,\n        packOutput,\n        scale);\n}\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &transducer_joint_forward, \"transducer joint forward (CUDA)\");\n  m.def(\"backward\", &transducer_joint_backward, \"transducer joint backward (CUDA)\");\n}"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
    "content": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <c10/macros/Macros.h>\n#include <THC/THC.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/CUDAGeneratorImpl.h>\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n#include <curand_kernel.h>\n#include \"philox.h\"\n\n// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.\n// width should be a power of 2 and should be less than warpSize.\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){\n    for (unsigned offset = width/2; offset > 0; offset /= 2){\n        x += __shfl_down_sync(0xffffffff, x, offset, width);   \n    }\n    return x;\n}\n\ninline int largestPowerOfTwo(int x){\n    int y = 1;\n    while (y <= x)\n        y <<= 1;\n    return y >> 1;\n}\n\n/*\nFigure out vectorization type for masks.\nSimilar to how PyTorch figures out acc_t here:\naten/src/ATen/AccumulateType.h \n*/\ntemplate <int V>\nstruct MaskVecType { };\n\ntemplate <> struct MaskVecType<1> { using type = uint8_t; };\ntemplate <> struct MaskVecType<2> { using type = uint16_t; };\ntemplate <> struct MaskVecType<4> { using type = uint32_t; };\n\ntemplate<int V>\nusing mvec_type = typename MaskVecType<V>::type;\n\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels.\n// For fwd, batch offset and stride are different for packing and non-packing mode.\nstruct OffsetCalFwd{\n    __device__ __forceinline__ OffsetCalFwd(\n        int64_t batch, \n        const int64_t *batchOffset, \n        int64_t maxFLen, \n        int64_t maxGLen, \n        int64_t gLen,\n        int64_t hiddenSize,\n        bool packOutput) :\n        batch(batch),\n        batchOffset(batchOffset),\n        maxFLen(maxFLen),\n        maxGLen(maxGLen),\n        gLen(gLen),\n        hiddenSize(hiddenSize),\n        packOutput(packOutput)\n        {}\n    \n    
int64_t batch;\n    const int64_t *batchOffset;\n    int64_t maxFLen;\n    int64_t maxGLen;\n    int64_t gLen;\n    int64_t hiddenSize;\n    bool packOutput;\n\n    __device__ __forceinline__ int64_t getBatchOffset(){\n        return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize \n                            : batch*maxFLen*maxGLen*hiddenSize;\n    }\n\n    __device__ __forceinline__ int64_t getStrideF(){\n        return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;\n    }\n\n    \n};\n\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels\n// For bwd, batch offset and stride are different for packing and non-packing mode.\n// The reduction is done for two input tensors. Therefore, generating two sets of offsets\n// according to bwdFasterDim can lead to a unified implementation in the actual kernel.\nstruct OffsetCalBwd{\n    __device__ __forceinline__ OffsetCalBwd(\n        int64_t batch, \n        const int64_t *batchOffset, \n        const int *fLen, \n        const int *gLen,\n        int64_t maxFLen, \n        int64_t maxGLen, \n        int64_t hiddenSize,\n        bool packOutput,\n        bool bwdFasterDim) :\n        batch(batch),\n        batchOffset(batchOffset),\n        maxFLen(maxFLen),\n        maxGLen(maxGLen),\n        fLen(fLen),\n        gLen(gLen),\n        hiddenSize(hiddenSize),\n        packOutput(packOutput),\n        bwdFasterDim(bwdFasterDim)\n        {}\n\n    int64_t batch;\n    const int64_t *batchOffset;\n    const int *fLen;\n    const int *gLen;\n    int64_t maxFLen;\n    int64_t maxGLen;\n    int64_t hiddenSize;\n    bool packOutput;\n    bool bwdFasterDim;  // whether doing bwd on the faster moving dimension\n\n    __device__ __forceinline__ int64_t getBatchOffset(){\n        return packOutput ? ((batch==0) ? 
0 : batchOffset[batch-1])*hiddenSize \n                            : batch*maxFLen*maxGLen*hiddenSize;\n    }\n\n    __device__ __forceinline__ int64_t getMaxXLen(){\n        return bwdFasterDim ? maxGLen : maxFLen;\n    }\n\n    __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){\n        return bwdFasterDim ? gLen[batch] : fLen[batch];\n    }\n\n    __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){\n        return bwdFasterDim ? fLen[batch] : gLen[batch];\n    }\n    \n    __device__ __forceinline__ int64_t getStrideX(){\n        return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);\n    }\n\n    __device__ __forceinline__ int64_t getStrideY(){\n        return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;\n    }\n};\n\n\n// Vanilla transducer joint forward kernel\n// Details of this joint function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n\n// f is a tensor of shape [batch, T, H]\n// g is a tensor of shape [batch, U, H]\n// the transducer joint does\n// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\n// The resultant tensor is of shape [batch, T, U, H]\n// Each thread block is working on one \"batch\" of data in the output tensor, [batch, t, u, :]\n\n// This joint function can optionally pack the output where the output tensor with a shape of\n// [B, T, U, H] is packed into [B_packed, H].\n// Don't-care region (t > fLen) or (u > gLen) is removed.\n// To enable packing, the starting offset for each batch needs to be specified with batchOffset.\ntemplate <typename scalar_t, class OffsetCal>\n__global__ void transducer_joint_forward(\n    const scalar_t *f,\n    const scalar_t *g,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    scalar_t *sum) {\n\n\n    const int batch = blockIdx.z;\n    const int t = 
blockIdx.y;\n    const int u = blockIdx.x;\n    const auto myFLen = fLen[batch];\n    const auto myGLen = gLen[batch];\n\n    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\n    const auto myBatchOffset = offsetCal.getBatchOffset();\n    const auto strideF = offsetCal.getStrideF();\n    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;\n    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;\n    scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;\n\n    if (t < myFLen and u < myGLen){\n        #pragma unroll\n        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){\n            if (h < hiddenSize){\n                mySum[h] = myF[h] + myG[h];\n            }\n        }\n    }\n    else if (packOutput == false and t < maxFLen and u < maxGLen){\n        // Need to write finite data to don't-care region because we instantiate the result tensor\n        // with torch::empty for performance reasons. 
Even though it is don't-care region, the \n        // contents need to be finite, otherwise could lead to NaN in WGRAD.\n        // In packing mode, this write is no longer necessary as we remove the don't-care region\n        // from the output.\n        // Picking -1 (over 0) here for ease of testing.\n        #pragma unroll\n        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){\n            if (h < hiddenSize){\n                mySum[h] = -1;\n            }\n        }    \n    }\n}\n\n/*\nTiled version of the joint forward kernel\nDetail of this joint function can be found in: \n[1] Sequence Transduction with Recurrent Neural Networks.\n\nf is a tensor of shape [batch, T, H]\ng is a tensor of shape [batch, U, H]\nthe transducer joint does\nsum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\nThe resultant tensor is of shape [batch, T, U, H]\nEach thread is working on a tile of the shape of tileF x tileG in the result tensor.\nThe input for the tile is first loaded in the register and is reused tileG and tileF times. \n\nThis joint function can optionally pack the output where the output tensor with a shape of\n[B, T, U, H] is packed into [B_packed, H].\nDon't-care region (t > fLen) or (u > gLen) is removed.\nTo enable packing, the starting offset for each batch need to be specified with batchOffset.\n\nOptionally this joint function performs ReLU and/or dropout on the joint output, which is \ncontrolled by arguments relu and dropout, respectively. philoxArgs is argument used for generating\npseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint\nfunction is a masked operation, which is controlled by the template argument masked. 
In this case, \nmasks are saved to backward.\n*/\ntemplate <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>\n__global__ void transducer_joint_tiled_forward(\n    const scalar_t *f,\n    const scalar_t *g,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    int64_t hiddenPerBlock,\n    bool packOutput,\n    bool relu, \n    bool dropout,\n    float p,\n    at::PhiloxCudaState philoxArgs,\n    scalar_t *sum,\n    uint8_t *mask) {\n\n    static_assert(U == 4, \"U has to be 4, as random numbers are generated in batch of 4\");\n\n    const int batch = blockIdx.z;\n    const int t = blockIdx.y * tileF;\n    const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\n    const int u = blockIdx.x / hiddenBlock * tileG;\n    const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;\n    const int h = threadIdx.x;\n    const auto myFLen = fLen[batch];\n    const auto myGLen = gLen[batch];\n\n    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\n    const auto myBatchOffset = offsetCal.getBatchOffset();\n    const auto strideF = offsetCal.getStrideF();\n\n    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;\n    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;\n    scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;\n    uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;\n\n    // The following code is only needed for dropout. We try to bypass them as much as possible.\n    auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) \n                            : std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));\n    uint64_t tid = masked ? 
(static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x + \n                        blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x\n                            : 0;\n    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); \n    scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;  \n    bool dropoutMask[U];\n\n    if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){    \n        // register buffers for tiled input reuse\n        scalar_t fBuffer[tileF], gBuffer[tileG];    \n        for (int i = 0; i < tileF; ++i){\n            if (t + i < myFLen)\n                fBuffer[i] = myF[i*hiddenSize + h];\n        }\n        for (int j = 0; j < tileG; ++j){\n            if (u + j < myGLen)\n                gBuffer[j] = myG[j*hiddenSize + h];\n        }\n        #pragma unroll\n        for (int i = 0; i < tileF; ++i){\n            if (t + i < myFLen){\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    int idx = i*tileG + j;\n                    if (masked and dropout and idx % U == 0){\n                        // For performance, generate 4 random numbers in one shot\n                        // auto rand4 = curand_uniform4(&state);\n                        auto rand4 = uniform4(ph());\n                        dropoutMask[0] = rand4.x < p;\n                        dropoutMask[1] = rand4.y < p;\n                        dropoutMask[2] = rand4.z < p;\n                        dropoutMask[3] = rand4.w < p;\n                    }\n\n                    if (u + j < myGLen){\n                        scalar_t out = fBuffer[i] + gBuffer[j];\n                        if (masked){\n                            // Apply ReLU here when relu is True\n                            bool localMask = relu ? (out>0) : 1;\n                            localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;\n                            out = dropout ? 
out*localMask*scale : out*localMask;\n                            myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);\n                        }\n                        mySum[i*strideF + j*hiddenSize + h] = out;\n                    }\n                    else if (packOutput == false and u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n            else if (packOutput == false and t + i < maxFLen){\n                // Again need to write finite data to don't-care region\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    if (u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n        }\n    }\n    else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){\n        // Only need to ensure finiteness in normal mode\n        #pragma unroll\n        for (int i = 0; i < tileF; ++i){\n            if (t + i < maxFLen){\n                #pragma unroll\n                for (int j = 0; j < tileG; ++j){\n                    if (u + j < maxGLen)\n                        mySum[i*strideF + j*hiddenSize + h] = -1;\n                }\n            }\n        }\n    }\n}\n\n/*\nBwd operation (reduction) on one input tensor. 
Since the operations performed on the two input\ntensors are exactly the same, only one kernel is needed, and the different indexing offsets\nand strides are handled by OffsetCalBwd.\n\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a \nnon-packed form.\n\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\n__device__ void transducer_joint_single_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    bool bwdFasterDim,  // whether bwd is on the faster-moving dimension (u)\n    float scale,\n    scalar_t *inGrad,\n    int yBlockOffset=0) {\n\n\n    const int batch = blockIdx.z;\n    // For the second input tensor, this offset needs to be subtracted because the first yBlockOffset\n    // sets of thread blocks are for the first input tensor.\n    const int x = blockIdx.y-yBlockOffset;\n    const int hOffset = blockIdx.x*C10_WARP_SIZE;\n    const int wid = threadIdx.y;\n    const int lid = threadIdx.x;\n    const int numWarp = blockDim.y;\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, \n                        bwdFasterDim);\n    const auto maxXLen = offsetCal.getMaxXLen();\n    const auto myXLen = offsetCal.getMyXLen();\n    const auto myYLen = offsetCal.getMyYLen();\n    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;\n    \n    if (x < myXLen){\n        \n        const auto myBatchOffset = offsetCal.getBatchOffset();\n        const auto strideX = offsetCal.getStrideX();\n        const auto strideY = 
offsetCal.getStrideY();\n        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;\n        const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;\n        \n        // Each warp reduces numYPerWarp \"y\" first\n        acc_t warpSum = 0;\n        auto numYPerWarp = (myYLen+numWarp-1)/numWarp;\n        #pragma unroll\n        for (int warpY = 0; warpY < numYPerWarp; ++warpY){\n            auto y = wid*numYPerWarp + warpY;\n            if (y < myYLen and (hOffset+lid) < hiddenSize)\n                if (masked)\n                    warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;\n                else    \n                    warpSum += myGrad[y*strideY + lid];\n        }\n\n        // transpose partial sum in SMEM and reduce further using warpReduce\n        smem[lid*numWarp + wid] = warpSum;\n        __syncthreads();\n        auto sum = smem[wid*C10_WARP_SIZE + lid];\n        sum = warpReduce(sum, numWarp);\n\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d d\n        // example of 4 warps (a, b, c, d) with 8 threads per warp\n        // Each warp needs 8 / 4 = 2 threads to write the results.\n        if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){\n            if (lid % numWarp == 0){\n                myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;\n            }\n        }\n    }\n    else if (wid == 0 and hOffset + lid < hiddenSize){\n        // Need to ensure the grad is zero for the don't-care region\n        myInGrad[lid] = 0;\n    }\n}\n\n/*\nThe actual bwd (reduction) kernel that gets launched.\nCalls transducer_joint_single_backward twice on the two input tensors. 
\nThe two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op \nuses the rest.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\n__global__ void transducer_joint_combined_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    float scale,\n    scalar_t *fGrad,\n    scalar_t *gGrad) {\n    if (blockIdx.y < maxFLen){\n        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            false,\n            scale,\n            fGrad);\n    }\n    else{\n        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            true,\n            scale,\n            gGrad,\n            maxFLen);\n    }  \n}\n\n/*\nVectorized version of transducer_joint_single_backward\nDoing exact same operation as transducer_joint_single_backward except the load and store are\nvectorized.\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a \nnon-packed form.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\n__device__ void 
transducer_joint_single_vec_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    bool bwdFasterDim,\n    float scale,\n    scalar_t *inGrad,\n    int yBlockOffset=0){\n\n    const int batch = blockIdx.z;\n    const int x = blockIdx.y - yBlockOffset;\n    const int hOffset = blockIdx.x*C10_WARP_SIZE*V;\n    const int wid = threadIdx.y;\n    const int lid = threadIdx.x;\n    const int numWarp = blockDim.y;\n\n    // Figure out the vectorization type for mask\n    using mvec_t = mvec_type<V>;\n\n    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, \n                        bwdFasterDim);\n    const auto maxXLen = offsetCal.getMaxXLen();\n    const auto myXLen = offsetCal.getMyXLen();\n    const auto myYLen = offsetCal.getMyYLen();\n    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    acc_t warpSum[V];\n    scalar_t inBuffer[V];\n    uint8_t maskBuffer[V];\n    scalar_t outBuffer[V];\n    auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);\n    auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);\n\n    if (x < myXLen){\n        const auto myBatchOffset = offsetCal.getBatchOffset();\n        const auto strideX = offsetCal.getStrideX();\n        const auto strideY = offsetCal.getStrideY();\n        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;\n        const uint8_t *myMask = masked ? 
mask + myBatchOffset + x*strideX + hOffset\n                                            :nullptr;\n\n        for (int i = 0; i < V; ++i)\n            warpSum[i] = 0;\n\n        // Each warp reduces numYPerWarp \"y\" first\n        auto numYPerWarp = (myYLen+numWarp-1)/numWarp;\n        for (int warpY = 0; warpY < numYPerWarp; ++warpY){\n            auto y = wid*numYPerWarp + warpY;\n            auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);\n            auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)\n                                        : nullptr;\n            auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);\n            auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);\n            if (hOffset + lid*V < hiddenSize and y < myYLen){\n                *inBufferVec = myGradVec[lid];  // vectorized load\n                if (masked){\n                    *maskBufferVec = myMaskVec[lid];\n                    #pragma unroll\n                    for (int i = 0; i < V; ++i)\n                        warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;\n                }\n                else{\n                    #pragma unroll\n                    for (int i = 0; i < V; ++i)\n                        warpSum[i] += inBuffer[i];\n                }\n            }\n        }\n        \n        // transpose partial sum in SMEM and reduce further using warpReduce\n        for (int i = 0; i < V; ++i){\n            smem[lid*numWarp + wid] = warpSum[i];\n            __syncthreads();\n            auto sum = smem[wid*C10_WARP_SIZE + lid];\n\n            if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){\n                sum = warpReduce(sum, numWarp);\n                if (lid % numWarp == 0){\n                    outBuffer[i] = sum;\n                }\n            }\n            __syncthreads();\n        }\n\n        // a a b b c c d d\n        // a a b b c c d d\n        // a a b b c c d 
d\n        // a a b b c c d d\n        // example of 4 warps (a, b, c, d) with 8 threads per warp\n        // Each warp needs 8 / 4 = 2 threads to write the results.\n        if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)\n            myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;     \n    }\n    else if (wid == 0 and hOffset + lid*V < hiddenSize){\n        // Need to ensure the grad is zero for the don't-care region\n        myInGradVec[lid] = 0;\n    }\n}\n\n/*\nVectorized version of transducer_joint_combined_backward\nCalls transducer_joint_single_vec_backward twice on the two input tensors. \nThe two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op \nuses the rest.\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\nand mask contains the mask information.\n*/\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\n__global__ void transducer_joint_combined_vec_backward(\n    const scalar_t *grad,\n    const uint8_t *mask,\n    const int *fLen,\n    const int *gLen,\n    const int64_t *batchOffset,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    int64_t hiddenSize,\n    bool packOutput,\n    float scale,\n    scalar_t *fGrad,\n    scalar_t *gGrad) {\n    if (blockIdx.y < maxFLen){\n        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            maxGLen,\n            hiddenSize,\n            packOutput,\n            false,\n            scale,\n            fGrad);\n    }\n    else{\n        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\n            grad,\n            mask,\n            fLen,\n            gLen,\n            batchOffset,\n            maxFLen,\n            
maxGLen,\n            hiddenSize,\n            packOutput,\n            true,\n            scale,\n            gGrad,\n            maxFLen);\n    }  \n}\n\n\n\n\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(\n    torch::Tensor f,\n    torch::Tensor g,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int64_t packedBatch,\n    int opt,\n    bool packOutput,\n    bool relu,\n    bool dropout,\n    float dropoutProb,\n    int tileSize){\n\n    \n    auto tensorOpt = f.options();\n    auto dtype = f.scalar_type();\n    const auto batchSize = f.size(0);\n    const auto maxFLen = f.size(1);\n    const auto maxGLen = g.size(1);\n    const auto hiddenSize = f.size(2);\n    bool masked = dropout or relu;\n    \n    int64_t *batchOffsetPtr = nullptr;\n    torch::Tensor sum, mask;\n    auto maskOpt = tensorOpt.dtype(torch::kUInt8);\n    if (!packOutput){\n        sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);\n        batchOffsetPtr = nullptr;\n        if (masked)\n            mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);\n    }\n    else{\n        sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);    \n        batchOffsetPtr = batchOffset.data_ptr<int64_t>();\n        if (masked)\n            mask = torch::empty({packedBatch, hiddenSize}, maskOpt);\n    }\n    uint8_t *maskPtr = masked ? 
mask.data_ptr<uint8_t>() : nullptr;\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    TORCH_CHECK(opt == 0 or opt == 1, \"Got an invalid optimization level \", opt);\n    // Simple heuristics\n    const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)\n                                        / C10_WARP_SIZE * C10_WARP_SIZE);\n    \n    if (opt == 0){\n        // vanilla kernel\n        const int threads = numThread;\n        const dim3 blocks(maxGLen, maxFLen, batchSize);\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_forward\", ([&] {\n            transducer_joint_forward<scalar_t, OffsetCalFwd>\n            <<<blocks, threads, 0, stream>>>(\n                f.data_ptr<scalar_t>(), \n                g.data_ptr<scalar_t>(), \n                fLen.data_ptr<int>(), \n                gLen.data_ptr<int>(), \n                batchOffsetPtr,\n                maxFLen,\n                maxGLen,\n                hiddenSize,\n                packOutput,\n                sum.data_ptr<scalar_t>());\n        }));  \n    }\n    if (opt == 1){\n        // tiled version. For simplicity, assume tileF == tileG, even though the kernel can \n        // support more general cases.\n        const int threads = numThread;\n        const int hiddenPerBlock = numThread;\n        const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\n        const dim3 blocks(  (maxGLen+tileSize-1)/tileSize * hiddenBlock, \n                            (maxFLen+tileSize-1)/tileSize, \n                            batchSize);\n\n        TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, \n                \"Expected tileSize to be in [1, 2, 4], but got \", tileSize);\n\n        at::PhiloxCudaState rng_engine_inputs;\n        if (masked){\n            // set up PRG when the input is masked. 
rng_engine_inputs will be used as a space filler \n            // for non-masked calls.\n            // Therefore no need to initialize.\n            c10::optional<at::Generator> gen_;\n            auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, \n                                                    at::cuda::detail::getDefaultCUDAGenerator());\n            // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, \n            // each thread processes tileF * tileG output elements. \n            int64_t counterOffset = tileSize * tileSize;\n            {\n                std::lock_guard<std::mutex> lock(gen->mutex_);\n                rng_engine_inputs = gen->philox_cuda_state(counterOffset);\n            }\n        }\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_forward\", ([&] {\n            void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, \n                            int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, \n                            at::PhiloxCudaState, scalar_t*, uint8_t*);\n            if (masked){\n                switch (tileSize){\n                    case 2:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, \n                                                                    true>;\n                        break;\n                    case 4:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, \n                                                                    true>;\n                        break;\n                }\n            }\n            else{\n                switch (tileSize){\n                    case 1:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n    
                case 2:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n                    case 4:\n                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, \n                                                                    false>;\n                        break;\n                }\n            }\n            \n            kernel<<<blocks, threads, 0, stream>>>(\n                f.data_ptr<scalar_t>(),\n                g.data_ptr<scalar_t>(),\n                fLen.data_ptr<int>(),\n                gLen.data_ptr<int>(),\n                batchOffsetPtr,\n                maxFLen,\n                maxGLen,\n                hiddenSize,\n                hiddenPerBlock,\n                packOutput,\n                relu,\n                dropout,\n                1.0f - dropoutProb,\n                rng_engine_inputs,\n                sum.data_ptr<scalar_t>(),\n                maskPtr);\n        }));  \n    }\n \n    THCudaCheck(cudaGetLastError());\n    if (masked) \n        return {sum, mask};\n    else\n        return {sum};\n}\n\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(\n    std::vector<torch::Tensor> in,\n    torch::Tensor fLen,\n    torch::Tensor gLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int maxGLen,\n    bool packOutput,\n    float scale){\n\n    auto grad = in[0];\n    bool masked = (in.size() == 2);\n    uint8_t *maskPtr = masked ? 
in[1].data_ptr<uint8_t>() : nullptr;\n\n    auto tensorOpt = grad.options();\n    auto dtype = grad.scalar_type();\n    const int batchSize = fLen.size(0);\n    const int hiddenSize = grad.size(-1);\n\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;\n\n    torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);\n    torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);\n\n    int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>(); \n\n    // The number \"y\" I would like each thread to work on\n    const int workPerThread = 32;   \n    // Since the bwd for f and g have the same thread block size, we need to use the max of the two.\n    int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);\n    // Would like to have at least 2 warps \n    numWarp = std::max(2, numWarp);\n    // cap on the maximum number of warps allowed\n    numWarp = std::min(maxNumWarp, numWarp); \n\n    // Need smem for transposing the partial sum. 
The partial sum is in a matrix of the shape\n    // numWarp x warpSize\n    const int smemSize = numWarp * C10_WARP_SIZE;\n    const dim3 threads(C10_WARP_SIZE, numWarp, 1);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_joint_cuda_backward_kernel\", ([&] {\n        auto gradPtr = grad.data_ptr<scalar_t>();\n        auto fLenPtr = fLen.data_ptr<int>();\n        auto gLenPtr = gLen.data_ptr<int>(); \n        auto fGradPtr = fGrad.data_ptr<scalar_t>();\n        auto gGradPtr = gGrad.data_ptr<scalar_t>();\n\n        // resolve the acc_t type\n        using acc_t = at::acc_type<scalar_t, true>;\n        using vec_t = uint64_t;\n\n        constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\n        constexpr int vecAlignment = std::alignment_of<vec_t>::value;\n\n        // if all input and output tensors meet the alignment requirement\n        bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0) \n                        and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0) \n                        and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);\n\n        if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){\n            // If vectorization helps and the alignment requirement is met, use the vectorized \n            // kernel. 
For simplicity, hiddenSize needs to be a multiple of vectFactor.\n            const dim3 blocks(  (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), \n                                maxFLen+maxGLen, \n                                batchSize);\n            if (masked){\n                transducer_joint_combined_vec_backward\n                    <scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>\n                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n            else{\n                transducer_joint_combined_vec_backward\n                <scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);    \n            }\n        }\n        else{\n            const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, \n                                maxFLen + maxGLen, batchSize);\n            if (masked){\n                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n     
               maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n            else{\n                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>\n                <<<blocks, threads, smemSize*sizeof(acc_t)>>>(\n                    gradPtr,\n                    maskPtr,\n                    fLenPtr, \n                    gLenPtr, \n                    batchOffsetPtr, \n                    maxFLen,\n                    maxGLen,\n                    hiddenSize,\n                    packOutput,\n                    scale,\n                    fGradPtr,\n                    gGradPtr);\n            }\n        }\n    }));   \n\n    return {fGrad, gGrad};\n}\n"
  },
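The joint-network forward implemented by the kernels above broadcasts `f[b, t, :] + g[b, u, :]` into a `[B, T, U, H]` tensor, optionally applies ReLU, and applies inverted dropout where `p` is the *keep* probability and surviving elements are rescaled by `1/p` (the kernel receives `1.0f - dropoutProb`). A minimal NumPy sketch of those semantics; `joint_forward_reference` and its arguments are illustrative names, not part of the extension's API:

```python
import numpy as np

def joint_forward_reference(f, g, relu=False, keep_prob=1.0, rng=None):
    """Reference for the fused joint forward: out[b,t,u,:] = f[b,t,:] + g[b,u,:]."""
    out = f[:, :, None, :] + g[:, None, :, :]   # broadcast-add to [B, T, U, H]
    mask = np.ones(out.shape, dtype=np.uint8)   # combined ReLU/dropout mask
    if relu:
        mask &= (out > 0).astype(np.uint8)      # ReLU folded into the mask
    if keep_prob < 1.0:
        if rng is None:
            rng = np.random.default_rng(0)      # hypothetical default seed
        mask &= (rng.random(out.shape) < keep_prob).astype(np.uint8)
        scale = 0.0 if keep_prob == 0 else 1.0 / keep_prob
        out = out * mask * scale                # inverted-dropout rescale
    else:
        out = out * mask
    return out, mask
```

As in the kernels, the saved `mask` is exactly what the backward reduction multiplies the incoming gradient by (together with `scale`), so no separate ReLU/dropout backward is needed.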
  {
    "path": "KoSimCSE/apex/contrib/csrc/transducer/transducer_loss.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput);\n\ntorch::Tensor transducer_loss_cuda_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput);\n\n\nstd::vector<torch::Tensor> transducer_loss_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor fLen,\n    torch::Tensor yLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput\n    ) {\n\n    CHECK_INPUT(x);\n    CHECK_INPUT(label);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(yLen);\n    if (packedInput)\n        CHECK_INPUT(batchOffset);\n    return transducer_loss_cuda_forward(\n        x, \n        label, \n        fLen, \n        yLen, \n        batchOffset,\n        maxFLen,\n        blankIdx, \n        opt,\n        packedInput);\n}\n\ntorch::Tensor transducer_loss_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor fLen,\n    torch::Tensor yLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput){\n\n    CHECK_INPUT(x);\n    CHECK_INPUT(label);\n    CHECK_INPUT(lossGrad);\n    
CHECK_INPUT(alpha);\n    CHECK_INPUT(beta);\n    CHECK_INPUT(fLen);\n    CHECK_INPUT(yLen);\n    if (packedInput)\n        CHECK_INPUT(batchOffset);\n\n    return transducer_loss_cuda_backward(\n        x,\n        lossGrad,\n        alpha,\n        beta,\n        fLen,\n        yLen,\n        label,\n        batchOffset,\n        maxFLen,\n        blankIdx,\n        opt,\n        fuseSoftmaxBackward,\n        packedInput);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &transducer_loss_forward, \"transducer loss forward (CUDA)\");\n  m.def(\"backward\", &transducer_loss_backward, \"transducer loss backward (CUDA)\");\n}\n"
  },
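Both the joint and loss ops take a `batchOffset` tensor when packed input/output is used: batch `b` contributes `fLen[b] * (yLen[b] + 1)` valid rows of the conceptual `[B, T, U, H]` tensor, and the kernels read `batchOffset[batch - 1]` as the starting row of each batch (so the offsets are inclusive cumulative sums). A small sketch of that layout, under that assumption; the helper names are illustrative:

```python
# Packed layout: batch b occupies fLen[b] * (yLen[b] + 1) rows of the
# [B_packed, H] tensor; batchOffset[b] is the cumulative row count *through*
# batch b, matching the kernels' "batch == 0 ? 0 : batchOffset[batch-1]" read.
def batch_offsets(f_len, y_len):
    offsets, total = [], 0
    for fl, yl in zip(f_len, y_len):
        total += fl * (yl + 1)
        offsets.append(total)
    return offsets

def packed_row(b, t, u, f_len, y_len, offsets):
    # Row index of element (b, t, u); the time stride is yLen[b] + 1 when
    # packed (myStrideT = myGLen), not the padded maxGLen.
    assert t < f_len[b] and u < y_len[b] + 1
    start = 0 if b == 0 else offsets[b - 1]
    return start + t * (y_len[b] + 1) + u
```

This is why `packedInput` callers must pass a contiguous `batchOffset`: every kernel converts a `(b, t, u)` coordinate to a flat row through these two quantities.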
  {
    "path": "KoSimCSE/apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
"content": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <vector>\n#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\ntemplate<typename scalar_t>\n__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {\n    // The standard log-sum-exp trick is used here for better numerical stability\n    return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));\n}\n\n// Vanilla transducer loss function (i.e. forward-backward algorithm)\n// Details of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n\n// Forward (alpha) and backward (beta) paths are launched together. Input is assumed to be converted\n// into log scale by the preceding log_softmax layer.\n// Diagonal wavefront advancing, commonly used in dynamic programming, is leveraged here. \n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into \n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_forward(\n    const scalar_t* x,\n    const int* label,\n    const int* audLen,\n    const int* txtLen,\n    const int64_t* batchOffset,\n    int64_t dictSize,   // 64-bit indexing for data tensor\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    acc_t* alpha,\n    acc_t* beta,\n    scalar_t* loss) {\n\n    const int batch = blockIdx.y;\n    const int tid = threadIdx.x;\n    const auto myFLen = audLen[batch];\n    // Note that the start of the sentence adds 1 here\n    const auto myGLen = txtLen[batch] + 1;  \n    const auto myLabel = label + batch * (maxGLen-1);\n    
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? myGLen : maxGLen;\n    const scalar_t* myX = x + myBatchOffset * dictSize; \n    int u  = tid;\n\n    if (blockIdx.x == 0){\n        // alpha path\n        acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;\n        if (u == 0) \n            myAlpha[0] = 0;\n        __syncthreads();\n\n        for (int64_t step = 1; step < myFLen+myGLen-1; ++step){\n            // Move along the diagonal wavefront to leverage available parallelism\n            for (u = tid; u < myGLen; u += blockDim.x){\n                int64_t t = step - u;\n                if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                    // Eq(16) in [1]\n                    if (u == 0){\n                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u)\n                        myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] \n                                                    + myX[((t-1)*myStrideT) * dictSize + blankIdx];\n                    }\n                    else if (t == 0){\n                        // alpha(t, u) = alpha(t, u-1) * y(t, u-1)\n                        myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];\n                    }\n                    else{\n                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)\n                        acc_t current = myAlpha[(t-1)*maxGLen + u] \n                                        + myX[((t-1)*myStrideT + u) * dictSize + blankIdx];\n                        acc_t next = myAlpha[t*maxGLen + u - 1] \n                                        + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];\n                        myAlpha[t*maxGLen + u] = logSumExp(next, current);\n                    }\n                }\n            }\n            __syncthreads();\n      
  }\n    }\n    else if (blockIdx.x == 1){\n        // beta path\n        acc_t* myBeta = beta + batch*maxFLen*maxGLen;\n        if (u == 0){\n            myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT \n                                                        + myGLen - 1) * dictSize + blankIdx];\n        }\n        __syncthreads();\n\n        for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){\n            for (u = tid; u < myGLen; u += blockDim.x){\n                int64_t t = step - u;\n                if (t >= 0 and t < myFLen and u >=0 and u < myGLen){\n                    // Eq(18) in [1]\n                    if (u == myGLen - 1){\n                        // beta(t, u) = beta(t+1, u) * null(t, u)\n                        myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] \n                                                + myX[(t*myStrideT + u) * dictSize + blankIdx];\n                    }\n                    else if (t == myFLen - 1){\n                        // beta(t, u) = beta(t, u+1) * y(t, u)\n                        myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] \n                                                + myX[(t*myStrideT + u) * dictSize + myLabel[u]];\n                    }\n                    else{\n                        // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)\n                        acc_t current = myBeta[(t+1)*maxGLen + u] \n                                        + myX[(t*myStrideT + u) * dictSize + blankIdx];\n                        acc_t next = myBeta[t*maxGLen + u + 1] \n                                        + myX[(t*myStrideT + u) * dictSize + myLabel[u]];\n                        myBeta[t*maxGLen + u] = logSumExp(next, current);\n                    }\n                }\n            }\n            __syncthreads();\n        }\n        if (tid == 0)\n            loss[batch] = -myBeta[0];   \n    }\n\n}\n\n// transducer loss function (i.e. 
forward-backward algorithm) with batch loading optimization.\n// Compared to the vanilla version, there are two optimizations:\n// 1. Load x in batch through loop unrolling to reduce latency.\n// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.\n// For simplicity, this kernel currently only supports U <= maxThread, which should be the common\n// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.\n\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// Forward (alpha) and backward (beta) paths are launched together. Input is assumed to be converted\n// into log scale by the preceding log_softmax layer.\n// Diagonal wavefront advancing, commonly used in dynamic programming, is leveraged here.\n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into \n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, int batchLdSize>\n__global__ void transducer_loss_batch_load_forward(\n    const scalar_t* x,\n    const int* label,\n    const int* audLen,\n    const int* txtLen,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    acc_t* alpha,\n    acc_t* beta,\n    scalar_t* loss) {\n\n    const int batch = blockIdx.y;\n    int u  = threadIdx.x;\n    const auto myFLen = audLen[batch];\n    const auto myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n    const scalar_t* myX = x + myBatchOffset * dictSize; \n    scalar_t next[batchLdSize], current[batchLdSize];\n    extern __shared__ char smem8[];\n    auto smem = reinterpret_cast<acc_t*>(smem8);\n\n    if (blockIdx.x == 0){\n        // alpha path\n        acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;\n        // two SMEM regions for double buffering read and write data to avoid data race\n        acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};\n\n        sharedAlpha[0][u] = 0; \n        __syncthreads();\n\n        if (u == 0)\n            myAlpha[0] = 0;\n\n        auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];\n        // register used to pass value to the next step for the same thread\n        acc_t prvStepAlpha = 0;\n        for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){\n            // Move along the diagonal wavefront to leverage available parallelism\n            // Batch loading X through loop unrolling\n            #pragma unroll\n            for (int i = 0; i < batchLdSize; ++i){\n                if (step+i<myFLen+myGLen-1){\n                    // index computing\n                    int64_t t = step + i - u;\n                    int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;\n                    int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;\n                    // main loading loop\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        if (u == 0){\n                            current[i] = myX[currentId];\n                        }\n                        else if (t == 0){\n                            next[i] = myX[nextId];\n                        }\n                        else{\n                            current[i] = myX[currentId];\n                            next[i] = myX[nextId];\n                        }\n                    }\n                }\n            }\n            // 
main computing loop\n            for (int i = 0; i < batchLdSize; ++i){\n                // swap the pointer for double buffering\n                auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];\n                auto sharedAlphaWr = sharedAlpha[(step+i)%2];\n                if (step+i<myFLen+myGLen-1){\n                    int64_t t = step + i - u;\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        // Eq(16) in [1]\n                        if (u == 0)\n                            prvStepAlpha = prvStepAlpha+current[i];\n                        else if (t == 0)\n                            prvStepAlpha = sharedAlphaRd[u-1]+next[i];\n                        else\n                            prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]\n                                            + next[i]);\n                        sharedAlphaWr[u] = prvStepAlpha;\n                        myAlpha[t*maxGLen + u] = prvStepAlpha;\n                    }\n                }\n                __syncthreads();\n            }\n        }\n    }\n    else if (blockIdx.x == 1){\n        // beta path\n        acc_t* myBeta = beta + batch*maxFLen*maxGLen;\n        // two SMEM regions for double buffering read and write data to avoid data race\n        acc_t * const sharedBeta[2] = {smem, smem + maxGLen};\n        sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];\n        __syncthreads();\n\n        auto myBetaLabel = (u == maxGLen - 1) ? 
0 : label[batch*(maxGLen-1) + u];\n        // register used to pass value to the next step for the same thread\n        acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];\n        if (u == 0)\n            myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;\n\n        for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){\n            // Move along the diagonal wavefront to leverage available parallelism\n            // Batch loading X\n            #pragma unroll\n            for (int i = 0; i < batchLdSize; ++i){\n                if (step+i<myFLen+myGLen-1){\n                    // index computing\n                    int64_t t = myFLen+myGLen - (step + i) - 2 - u;\n                    int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;\n                    int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;\n                    // main loading loop\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        if (u == myGLen - 1){\n                            current[i] = myX[currentId];\n                        }\n                        else if (t == myFLen - 1){\n                            next[i] = myX[nextId];\n                        }\n                        else{\n                            current[i] = myX[currentId];\n                            next[i] = myX[nextId];\n                        }\n                    }\n                }\n            }\n            // main computing loop\n            for (int i = 0; i < batchLdSize; ++i){\n                // swap the pointer for double buffering\n                auto sharedBetaRd = sharedBeta[(step+i-1)%2];\n                auto sharedBetaWr = sharedBeta[(step+i)%2];\n                if (step+i<myFLen+myGLen-1){\n                    int64_t t = myFLen+myGLen - (step + i) - 2 - u;\n                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){\n                        // Eq(18) in [1]\n    
                    if (u == myGLen - 1)\n                            prvStepBeta = prvStepBeta+current[i];\n                        else if (t == myFLen - 1)\n                            prvStepBeta = sharedBetaRd[u+1]+next[i];\n                        else\n                            prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]\n                                            + next[i]);\n                        sharedBetaWr[u] = prvStepBeta;\n                        myBeta[t*maxGLen + u] = prvStepBeta;\n                    }\n                    \n                }\n                __syncthreads();\n            }\n        }\n        if (u == 0)\n            loss[batch] = -prvStepBeta; \n    }\n\n}\n\n// Vanilla transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, \n// hence only Eq(20) in [1] is implemented in this kernel.\n\n// Each thread block works on [batch, t, :, :] of data. 
Each thread works on a specific u at a time\n// Since only gradients for the correct token and null token need to be updated, gradients at other\n// locations are initialized to 0.\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n\n    const int tid = threadIdx.x;\n    const int t = blockIdx.x;\n    const int batch = blockIdx.y;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n    auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;\n    auto myAlpha = alpha + batch*maxFLen*maxGLen;\n    auto myBeta = beta + batch*maxFLen*maxGLen;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; \n    auto myLabel = label + batch*(maxGLen-1);\n\n    int64_t u = tid;\n    while (t < myFLen and u < myGLen){\n        // Do the update\n        // loss = -ln(Pr(y*|x))\n        acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];  \n        if (u != myGLen - 1)\n            myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] \n                                                + myX[u*dictSize + myLabel[u]]);\n        if (t == myFLen - 1 and u == myGLen - 1)\n            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);\n        else if (t != myFLen - 1)\n            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] \n                                                + myX[u*dictSize + blankIdx]); \n\n        u += blockDim.x;\n    }\n}\n\n// Fused transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this kernel. \n// Each thread block works on [batch, t, u, :] of data. 
Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_fused_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n    \n    const int tid = threadIdx.x;\n    const int u = blockIdx.x;\n    const int t = blockIdx.y;\n    const int batch = blockIdx.z;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; \n\n    if (t < myFLen and u < myGLen){ \n        auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; \n        auto myAlpha = alpha + batch*maxFLen*maxGLen;\n        auto myBeta = beta + batch*maxFLen*maxGLen;\n        auto myLabel = label + batch*(maxGLen-1);\n\n        // load and store shared variables in SMEM\n        if (tid == 0){\n            commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];\n            myBetaTU = myBeta[t*maxGLen + u];\n            myBetaTUp1 = myBeta[t*maxGLen + u + 1];\n            myBetaTp1U = myBeta[(t+1)*maxGLen + u];\n            myLabelShared = myLabel[u];\n        }\n\n        __syncthreads();\n\n        for (int64_t h = tid; h < dictSize; h += blockDim.x){\n            // Do the update\n            acc_t grad = commonFactor + myX[h];  // loss = -ln(Pr(y*|x))\n            acc_t myGrad = std::exp(grad + myBetaTU);\n            if (u != myGLen - 1 and h == myLabelShared){\n                myGrad -= std::exp(grad + myBetaTUp1);\n            }\n            else if (h == blankIdx){\n                if (t == myFLen - 1 and u == myGLen - 1)\n                    myGrad -= std::exp(grad);\n                else if (t != myFLen - 1)\n                    myGrad -= std::exp(grad + myBetaTp1U);\n            }\n            myXGrad[h] = myGrad;\n        }\n    }\n    else if (!packedInput){\n        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n        for (int64_t h = tid; h < dictSize; h += blockDim.x){\n            myXGrad[h] = 0;\n        }\n    }\n}\n\n\n// Vectorized version of fused transducer loss backward operation.\n// Detail of this loss function can be found in: \n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this 
kernel. \n// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V>\n__global__ void transducer_loss_fused_vec_backward(\n    const scalar_t* x,\n    const scalar_t* lossGrad,\n    const int* audLen,\n    const int* txtLen,\n    const int* label,\n    const acc_t* alpha,\n    const acc_t* beta,\n    const int64_t* batchOffset,\n    int64_t dictSize,\n    int64_t blankIdx,\n    int64_t maxFLen,\n    int64_t maxGLen,\n    bool packedInput,\n    scalar_t* xGrad) {\n    \n    const int tid = threadIdx.x;\n    const int u = blockIdx.x;\n    const int t = blockIdx.y;\n    const int batch = blockIdx.z;\n    const int64_t myFLen = audLen[batch];\n    const int64_t myGLen = txtLen[batch] + 1;\n    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) \n                                                : batch * maxFLen * maxGLen;\n    const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; \n    auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; \n    auto myAlpha = alpha + batch*maxFLen*maxGLen;\n    auto myBeta = beta + batch*maxFLen*maxGLen;\n    auto myLabel = label + batch*(maxGLen-1);\n\n    // Variables for vectorization\n    scalar_t myXBuffer[V], myXGradBuffer[V];\n    auto myXVec = reinterpret_cast<vec_t const *>(myX);\n    auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);\n    auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);\n    auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);\n    if (t < myFLen and u < myGLen){ \n        // load and store shared variables in SMEM\n        if (tid == 0){\n            commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];\n            myBetaTU = myBeta[t*maxGLen + u];\n            if (t != myFLen - 1)\n                myBetaTp1U = myBeta[(t+1)*maxGLen + u];\n            if (u != myGLen - 1){\n                myBetaTUp1 = myBeta[t*maxGLen + u + 1];\n                myLabelShared = myLabel[u];\n            }\n        }\n\n        __syncthreads();\n\n        #pragma unroll\n        for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){\n            // Load myX in a vector form\n            *myXBufferVec = myXVec[h0/V];\n            // Do the update for a vector of input\n            #pragma unroll\n            for (int i = 0; i < V; ++i){\n                auto h = h0 + i;\n                acc_t grad = commonFactor + myXBuffer[i];  // loss = -ln(Pr(y*|x))\n                acc_t myGrad = std::exp(grad + myBetaTU);\n                if (u != myGLen - 1 and h == myLabelShared){\n                    myGrad -= std::exp(grad + myBetaTUp1);\n                }\n                else if (h == blankIdx){\n                    if (t == myFLen - 1 and u == myGLen - 1)\n                    
    myGrad -= std::exp(grad);\n                    else if (t != myFLen - 1)\n                        myGrad -= std::exp(grad + myBetaTp1U);\n                }\n                myXGradBuffer[i] = myGrad;\n            }\n\n            // Store myXGrad in a vector form\n            myXGradVec[h0/V] = *myXGradBufferVec;\n            \n        }\n    }\n    else if (!packedInput){\n        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n        for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){\n            myXGradVec[h0/V] = 0;\n        }\n    }\n}\n\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(\n    torch::Tensor x,\n    torch::Tensor label,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool packedInput){\n\n    auto scalarType = x.scalar_type();\n    auto tensorOpt = x.options();\n    const int batchSize = label.size(0);\n    const int maxGLen = label.size(1) + 1;\n    const int dictSize = x.size(-1);\n\n    TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, \n                \"Expected blank index to be in the range of 0 to \", \n                dictSize-1,\n                \", but got \", \n                blankIdx);\n    TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, \n                \"Got an invalid optimization level \", \n                opt);\n\n    // The data type of alpha and beta will be resolved at dispatch time,\n    // hence defined here and assigned later\n    torch::Tensor alpha;    \n    torch::Tensor beta;\n    torch::Tensor loss = torch::empty({batchSize}, tensorOpt);\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n    const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;\n    const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, \"transducer_loss_cuda_forward\", ([&] {\n        // resolve accumulation type\n        using acc_t = at::acc_type<scalar_t, true>;\n        auto accType = c10::CppTypeToScalarType<acc_t>::value;\n        auto accTensorOpt = tensorOpt.dtype(accType);\n        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n\n        // decide what kernel to launch based on the problem size\n        // if the required SMEM size or number of threads exceeds the limit, fall back to the vanilla\n        // kernel.\n        const auto smemSize = 2*maxGLen*sizeof(acc_t);\n        const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 \n                                    : (opt == -1) ? 1 : opt;\n        const int threads = std::min(maxThreadPerBlock, maxGLen);\n        const dim3 blocks(2, batchSize, 1);        \n\n        if (optFallBack == 0)\n            transducer_loss_forward<<<blocks, threads, 0, stream>>>(\n                x.data_ptr<scalar_t>(), \n                label.data_ptr<int>(), \n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                batchOffsetPtr,\n                dictSize, \n                blankIdx, \n                maxFLen,\n                maxGLen,\n                packedInput,\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                loss.data_ptr<scalar_t>());\n        else if (optFallBack == 1)\n            transducer_loss_batch_load_forward<scalar_t, acc_t, 4>\n            <<<blocks, threads, smemSize, stream>>>(\n                x.data_ptr<scalar_t>(), \n                label.data_ptr<int>(), \n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                
batchOffsetPtr,\n                dictSize, \n                blankIdx, \n                maxFLen,\n                maxGLen,\n                packedInput,\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                loss.data_ptr<scalar_t>());  \n\n    }));\n    THCudaCheck(cudaGetLastError());\n\n    return {alpha, beta, loss};\n}\n\n\n\n\ntorch::Tensor transducer_loss_cuda_backward(\n    torch::Tensor x,\n    torch::Tensor lossGrad,\n    torch::Tensor alpha,\n    torch::Tensor beta,\n    torch::Tensor audLen,\n    torch::Tensor txtLen,\n    torch::Tensor label,\n    torch::Tensor batchOffset,\n    int maxFLen,\n    int blankIdx,\n    int opt,\n    bool fuseSoftmaxBackward,\n    bool packedInput){\n\n    auto dtype = x.scalar_type();\n    torch::Tensor xGrad;\n    const int batchSize = label.size(0);\n    const int maxGLen = label.size(1) + 1;\n    const int dictSize = x.size(-1);\n    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n    const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n    const int warpSize = deviceProperties->warpSize;\n    const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    \n    if (fuseSoftmaxBackward){\n        // alloc empty tensors for performance, hence need to ensure zeros are written to \n        // the don't-care region in the kernel.\n        xGrad = torch::empty_like(x);\n\n        // Would like each thread to work on 4 hidden units\n        const int workPerThread = 4;  \n        // Don't want to have more than 128 threads per thread block\n        const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);\n        const int threads = std::min(maxThreadPerElmt, std::max(warpSize, \n                                    (dictSize+workPerThread-1)/workPerThread));\n        const dim3 blocks(maxGLen, maxFLen, batchSize);\n\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_loss_cuda_backward\", ([&] {\n            using vec_t = uint64_t;\n            using acc_t = at::acc_type<scalar_t, true>;\n            constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\n            constexpr int vecAlignment = std::alignment_of<vec_t>::value;\n            // if all input and output tensors meet the alignment requirement\n            bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0\n                                and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>()) \n                                        % vecAlignment == 0;\n\n            if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){\n                transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>\n                    <<<blocks, threads, 0, stream>>>(    \n                    x.data_ptr<scalar_t>(), \n                    lossGrad.data_ptr<scalar_t>(),\n                    audLen.data_ptr<int>(), \n                    txtLen.data_ptr<int>(), \n                    label.data_ptr<int>(),\n                    alpha.data_ptr<acc_t>(), \n                    
beta.data_ptr<acc_t>(),  \n                    batchOffsetPtr,\n                    dictSize, \n                    blankIdx, \n                    maxFLen,\n                    maxGLen,\n                    packedInput,\n                    xGrad.data_ptr<scalar_t>());   \n            }\n            else{\n                transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(    \n                    x.data_ptr<scalar_t>(), \n                    lossGrad.data_ptr<scalar_t>(),\n                    audLen.data_ptr<int>(), \n                    txtLen.data_ptr<int>(), \n                    label.data_ptr<int>(),\n                    alpha.data_ptr<acc_t>(), \n                    beta.data_ptr<acc_t>(),  \n                    batchOffsetPtr,\n                    dictSize, \n                    blankIdx, \n                    maxFLen,\n                    maxGLen,\n                    packedInput,\n                    xGrad.data_ptr<scalar_t>());   \n                \n            }\n        }));\n    }\n    else{\n        // for the non-fused kernel, the gradients that need to be written are very sparse, hence\n        // initialize the tensor with all zeros.\n        xGrad = torch::zeros_like(x);\n        // don't launch more threads than needed.\n        const int threads = std::min(maxThreadPerBlock, maxGLen);\n        const dim3 blocks(maxFLen, batchSize);\n        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_loss_cuda_backward\", ([&] {\n            using acc_t = at::acc_type<scalar_t, true>;\n            transducer_loss_backward<<<blocks, threads, 0, stream>>>(    \n                x.data_ptr<scalar_t>(), \n                lossGrad.data_ptr<scalar_t>(),\n                audLen.data_ptr<int>(), \n                txtLen.data_ptr<int>(), \n                label.data_ptr<int>(),\n                alpha.data_ptr<acc_t>(), \n                beta.data_ptr<acc_t>(), \n                batchOffsetPtr, \n                dictSize, \n                blankIdx, \n    
            maxFLen,\n                maxGLen,\n                packedInput,\n                xGrad.data_ptr<scalar_t>());\n        }));\n    }\n    THCudaCheck(cudaGetLastError());\n    \n    return xGrad;\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/xentropy/interface.cpp",
    "content": "#include <torch/extension.h>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> softmax_xentropy_cuda(\n    const at::Tensor &input,\n    const at::Tensor &labels,\n    const float smoothing,\n    const bool half_to_float);\n\nat::Tensor softmax_xentropy_backward_cuda(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing);\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> softmax_xentropy_forward(\n    const at::Tensor &input,\n    const at::Tensor &labels,\n    const float smoothing,\n    const bool half_to_float) {\n    CHECK_CUDA(input);\n    CHECK_INPUT(labels);\n\n    return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing)  {\n    CHECK_CUDA(grad_loss);\n    CHECK_CUDA(logits);\n    CHECK_INPUT(max_log_sum_exp);\n    CHECK_INPUT(labels);\n\n    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"forward\", &softmax_xentropy_forward, \"Softmax cross entropy loss with label smoothing forward (CUDA)\");\n    m.def(\"backward\", &softmax_xentropy_backward, \"Softmax cross entropy loss with label smoothing backward (CUDA)\");\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/csrc/xentropy/xentropy_kernel.cu",
    "content": "/**\n * From PyTorch:\n *\n * Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n * Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n * Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n * Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n * Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n *\n * From Caffe2:\n *\n * Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n *\n * All contributions by Facebook:\n * Copyright (c) 2016 Facebook Inc.\n *\n * All contributions by Google:\n * Copyright (c) 2015 Google Inc.\n * All rights reserved.\n *\n * All contributions by Yangqing Jia:\n * Copyright (c) 2015 Yangqing Jia\n * All rights reserved.\n *\n * All contributions from Caffe:\n * Copyright(c) 2013, 2014, 2015, the respective contributors\n * All rights reserved.\n *\n * All other contributions:\n * Copyright(c) 2015, 2016 the respective contributors\n * All rights reserved.\n *\n * Caffe2 uses a copyright model similar to Caffe: each contributor holds\n * copyright over their contributions to Caffe2. The project versioning records\n * all such contribution and copyright details. If a contributor wants to further\n * mark their specific copyright on a particular contribution, they should\n * indicate their copyright solely in the commit message of the change when it is\n * committed.\n *\n * All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. 
Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *\n * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n *    and IDIAP Research Institute nor the names of its contributors may be\n *    used to endorse or promote products derived from this software without\n *    specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n * POSSIBILITY OF SUCH DAMAGE.\n */\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/NumericLimits.cuh>\n\n#include <THC/THC.h>\n#include <THC/THCGeneral.h>\n#include <THC/THCThrustAllocator.cuh>\n\n#include \"type_shim.h\"\n#include \"compat.h\"\n\n#define ALIGN_BYTES 16\n\nusing Tensor = at::Tensor;\nusing TensorList = at::TensorList;\nusing ScalarType = at::ScalarType;\nusing at::acc_type;\n\ntemplate<typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxForwardEpilogue {\n  __device__ 
__forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)\n    : logsum(max_input + std::log(sum)) {}\n\n  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)\n    : logsum(max_log_sum_exp) {}\n\n  __device__ __forceinline__ OutT operator()(T input) const {\n    return static_cast<OutT>(input - logsum);\n  }\n\n  const AccumT logsum;\n};\n\ntemplate<typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxBackwardEpilogue {\n  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)\n    : sum(sum) {}\n\n  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {\n    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);\n  }\n\n  const AccumT sum;\n};\n\n\n\nconst int max_threads = 1024;\n\ninline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {\n  uint64_t block_size = 1;\n  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));\n  while (block_size < (max_block_size/2)) block_size *= 2;\n  // Launch at least a single warp - the kernel assumes that.\n  block_size = std::max(block_size, static_cast<uint64_t>(32));\n  return dim3(block_size);\n}\n\ntemplate<typename T>\nstruct Add {\n  __device__ __forceinline__ T operator()(T a, T b) const {\n    return a + b;\n  }\n};\n\ntemplate<typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const {\n    return a < b ? 
b : a;\n  }\n};\n\n\n////////////////////////////////////////////////////////////////////////////////\n// Regular kernel (fast when dim_size is large; requires inner_size == 1)\n////////////////////////////////////////////////////////////////////////////////\n\n\ntemplate <typename T, typename AccumT>\nstruct MaxFloat\n{\n  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {\n    return ::max(max, (AccumT)v);\n  }\n};\n\ntemplate<typename T, typename AccumT>\nstruct AddFloat\n{\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {\n    return sum + v;\n  }\n};\n\ntemplate<typename T, typename AccumT>\nstruct SumExpFloat\n{\n  __device__ __forceinline__ SumExpFloat(AccumT v)\n    : max_k(v) {}\n\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {\n    return sum + std::exp(v - max_k);\n  }\n\n  const AccumT max_k;\n};\n\ntemplate <template<typename> class Reduction, typename AccumT>\n__device__ __forceinline__ AccumT\nblockReduce(AccumT* smem, AccumT val,\n            const Reduction<AccumT>& r,\n            AccumT defaultVal)\n{\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val;\n\n  __syncthreads();\n\n  AccumT warpVal = defaultVal;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal = r(warpVal, smem[lane * 32 + i]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal = defaultVal;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal = r(blockVal, smem[i]);\n    }\n    smem[0] = blockVal;\n  
}\n\n  // Sync and broadcast\n  __syncthreads();\n  return smem[0];\n}\n\ntemplate <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>\n__device__ __forceinline__ void\nblockReduce(AccumT* smem,\n            AccumT* reducVal1,\n            AccumT val1,\n            const Reduction1<AccumT>& r1,\n            AccumT defaultVal1,\n            AccumT* reducVal2,\n            AccumT val2,\n            const Reduction2<AccumT>& r2,\n            AccumT defaultVal2)\n{\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val1;\n  smem[blockDim.x + threadIdx.x] = val2;\n\n  __syncthreads();\n\n  AccumT warpVal1 = defaultVal1;\n  AccumT warpVal2 = defaultVal2;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);\n        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal1;\n      smem[lane + blockDim.x] = warpVal2;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal1 = defaultVal1;\n  AccumT blockVal2 = defaultVal2;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal1 = r1(blockVal1, smem[i]);\n      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);\n    }\n    smem[0] = blockVal1;\n    smem[blockDim.x] = blockVal2;\n  }\n\n  // Sync and broadcast\n  __syncthreads();\n  *reducVal1 = smem[0];\n  *reducVal2 = smem[blockDim.x];\n  __syncthreads();\n}\n\ntemplate <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>\n__device__ __forceinline__ 
AccumT\nilpReduce(int shift,\n          T* data,\n          int size,\n          const Reduction<T, AccumT>& r,\n          AccumT defaultVal)\n{\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;\n  AccumT threadVal = defaultVal;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    data -= shift;\n    size += shift;\n    if(threadIdx.x >= shift){\n      threadVal = r(threadVal, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal = r(threadVal, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x)\n    threadVal = r(threadVal, data[offset]);\n\n  return threadVal;\n}\n\ntemplate <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>\n__device__ __forceinline__ void\nilpReduce(int shift,\n          T* data,\n          int size,\n          AccumT* reducVal1,\n          const Reduction1<T, AccumT>& r1,\n          AccumT defaultVal1,\n          AccumT* reducVal2,\n          const Reduction2<T, AccumT>& r2,\n          AccumT defaultVal2)\n{\n  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;\n\n  AccumT threadVal1 = defaultVal1;\n  AccumT threadVal2 = defaultVal2;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    data -= shift;\n    size += shift;\n    if(threadIdx.x >= shift){\n      threadVal1 = r1(threadVal1, data[offset]);\n      threadVal2 = r2(threadVal2, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = 
reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal1 = r1(threadVal1, v[j]);\n      threadVal2 = r2(threadVal2, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x) {\n    threadVal1 = r1(threadVal1, data[offset]);\n    threadVal2 = r2(threadVal2, data[offset]);\n  }\n\n  *reducVal1 = threadVal1;\n  *reducVal2 = threadVal2;\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>\n__global__ void\ncunn_SoftMaxXEntropyForward(\n    accscalar_t *losses,\n    outscalar_t *max_log_sum_exp,\n    scalar_t *input,\n    int64_t *labels,\n    int64_t classes,\n    const float smoothing)\n{\n  extern __shared__ unsigned char smem[];\n  auto sdata = reinterpret_cast<accscalar_t*>(smem);\n  // forward pointers to batch[blockIdx.x]\n  // each block handles a sample in the mini-batch\n  input += blockIdx.x * classes;\n  //output += blockIdx.x * classes;\n  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);\n\n  int64_t label = labels[blockIdx.x];\n\n  // find the max and sum\n  accscalar_t threadMax, threadSum, max_k, sum_k;\n  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(\n    shift, input, classes,\n    &threadMax, MaxFloat<scalar_t, accscalar_t>(),\n    -at::numeric_limits<accscalar_t>::max(),\n    &threadSum, AddFloat<scalar_t, accscalar_t>(),\n    static_cast<accscalar_t>(0));\n\n  blockReduce<Max, Add, accscalar_t>(\n      sdata,\n      &max_k, threadMax, Max<accscalar_t>(),\n      -at::numeric_limits<accscalar_t>::max(),\n      &sum_k, threadSum, Add<accscalar_t>(),\n      static_cast<accscalar_t>(0));\n\n  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, 
accscalar_t>(max_k), static_cast<accscalar_t>(0));\n  accscalar_t sumAll = blockReduce<Add, accscalar_t>(\n      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));\n\n  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);\n\n  // calculate per element loss with label smoothing\n  // reserve max + log_sum_exp for bprop\n  if (threadIdx.x == 0) {\n    accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));\n    losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \\\n      * smoothing - log_prob * (1 - smoothing);\n    max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);\n  }\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void\napply(scalar_t *gradInput,\n      scalar_t *logits,\n      outscalar_t *max_log_sum_exp,\n      outscalar_t *gradOutput,\n      int64_t *labels,\n      const float smoothing,\n      int classes)\n{\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n  int last = classes % (ILP * blockDim.x);\n\n  for (; offset < classes - last; offset += blockDim.x * ILP) {\n    accscalar_t tmpLogits[ILP];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);\n    }\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j)\n      gradInput[offset + j * blockDim.x] = tmpGradOutput * (\n        std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(\n          (offset + j * blockDim.x == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n  }\n\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>((offset == label) ? 1 : 0) *\n        smooth_positives - smooth_negatives);\n}\n\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void\naligned_apply(int shift,\n              scalar_t *gradInput,\n              scalar_t *logits,\n              outscalar_t *max_log_sum_exp,\n              outscalar_t *gradOutput,\n              int64_t *labels,\n              const float smoothing,\n              int classes)\n{\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if(shift > 0){\n    logits -= shift;\n    gradInput -= shift;\n    classes += shift;\n    if(threadIdx.x >= shift){\n      gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n    }\n    classes -= blockDim.x;\n    gradInput += blockDim.x;\n    logits += blockDim.x;\n    shift -= blockDim.x;\n  }\n\n  int last = classes % (ILP * blockDim.x);\n\n  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;\n  // input\n  scalar_t v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n  // output\n  scalar_t r[ILP];\n  LoadT* result = reinterpret_cast<LoadT*>(&r);\n\n  for (; offset * ILP < (classes - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(logits)[offset];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      r[j] = tmpGradOutput * (std::exp(\n          static_cast<accscalar_t>(v[j]) - coeff) -\n          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *\n          smooth_positives - smooth_negatives);\n    }\n    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;\n  }\n\n  offset = classes - last + threadIdx.x;\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] = tmpGradOutput * (std::exp(\n        static_cast<accscalar_t>(logits[offset]) - coeff) -\n        static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) *\n        smooth_positives - smooth_negatives);\n\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>\n__global__ void\ncunn_SoftMaxXEntropyBackward(\n    scalar_t *gradInput,\n    scalar_t *logits,\n    outscalar_t *max_log_sum_exp,\n    outscalar_t *gradOutput,\n    int64_t *labels,\n    const float smoothing,\n    int classes)\n{\n  gradInput += blockIdx.x * classes;\n  logits += blockIdx.x * classes;\n\n  // Do vectorized load/store when input/output have same alignment\n  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);\n  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);\n  if (shift == shift_){\n    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);\n  }\n  else {\n    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);\n  }\n\n}\n\ntemplate<template<typename, typename, typename> class Epilogue>\nstd::vector<Tensor> host_softmax_xentropy(\n        const Tensor & input_,\n        const Tensor & labels_,\n        const float smoothing,\n        const bool half_to_float){\n  if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,\"conversion is supported for Half type only\");\n  AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,\"Label type should be CUDA Long\");\n\n  auto input = input_.contiguous();\n  Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? 
input.options().dtype(ScalarType::Float) : input.options());\n  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));\n\n  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||\n    std::is_same<acc_type<at::Half, true>, double>::value,\n    \"accscalar_t for half should be float or double\");\n  AT_ASSERTM(input.dim() == 2, \"Currently only 2 dim input supported\");\n  AT_ASSERTM(labels_.dim() == 1, \"Labels should be 1 dimensional\");\n  AT_ASSERTM(input.size(0) == labels_.size(0), \"Input and label should have same number of examples\");\n  AT_ASSERTM(input.numel() > 0, \"Number of classes in input should not be 0\");\n\n  const int64_t dim = 1;\n  int64_t outer_size = 1;\n  int64_t dim_size = input.size(dim);\n  int64_t inner_size = 1;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  for (int64_t i = 0; i < dim; ++i)\n    outer_size *= input.size(i);\n  for (int64_t i = dim + 1; i < input.dim(); ++i)\n    inner_size *= input.size(i);\n  // This kernel spawns a block per each element in the batch.\n  // XXX: it assumes that inner_size == 1\n  TORCH_CHECK(inner_size == 1, \"Currently only inner size 1 supported\");\n\n  dim3 grid(outer_size);\n\n  using namespace at;\n  DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, \"host_softmax_xentropy\",\n    using accscalar_t = at::acc_type<scalar_t_0, true>;\n    const int ILP = sizeof(float4)/sizeof(scalar_t_0);\n    dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n    if (!half_to_float) {\n      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(\n          losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<scalar_t_0>(),\n          input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),\n          dim_size, smoothing\n      );\n    } else {\n      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n        <<<grid, block, 
2 * block.x * sizeof(accscalar_t), stream>>>(\n          losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<accscalar_t>(),\n          input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),\n          dim_size, smoothing\n      );\n    }\n  );\n\n  THCudaCheck(cudaGetLastError());\n\n  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};\n  return ret;\n}\n\ntemplate<template<typename, typename, typename> class Epilogue>\nTensor host_softmax_xentropy_backward(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits_,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing,\n    bool half_to_float) {\n  const int64_t dim = 1;\n  Tensor gI = at::empty_like(logits_);\n  if (grad_loss.numel() == 0) {\n    return gI;\n  }\n\n  auto grad = grad_loss.contiguous();\n  auto logits = logits_.contiguous();\n\n  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||\n    std::is_same<acc_type<at::Half, true>, double>::value,\n    \"accscalar_t for half should be float or double\");\n  if (grad.dim() == 0) grad = grad.view(1);\n\n  AT_ASSERTM(logits_.dim() == 2, \"Currently only 2 dim input supported\");\n  AT_ASSERTM(labels.dim() == 1, \"Labels should be 1 dimensional\");\n  AT_ASSERTM(logits_.numel() > 0, \"Number of classes in input should not be 0\");\n  AT_ASSERTM(logits_.size(0) == labels.size(0), \"Input and label should have same number of examples\");\n  AT_ASSERTM(labels.size(0) == grad.size(0), \"Label and loss should have same number of examples\");\n\n  int64_t outer_size = 1;\n  int64_t dim_size = logits.size(dim);\n  int64_t inner_size = 1;\n  for (int64_t i = 0; i < dim; ++i)\n    outer_size *= logits.size(i);\n  for (int64_t i = dim + 1; i < logits.dim(); ++i)\n    inner_size *= logits.size(i);\n  // See descriptions of kernels above.\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  TORCH_CHECK(inner_size == 1, \"Currently only inner size 1 supported\");\n\n  dim3 
grid(outer_size);\n\n  DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, \"host_softmax_xentropy_backward\",\n    using accscalar_t = acc_type<scalar_t_0, true>;\n    const int ILP = sizeof(float4)/sizeof(scalar_t_0);\n    dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n    if (!half_to_float) {\n      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n       <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n          gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),\n          max_log_sum_exp.DATA_PTR<scalar_t_0>(),\n          grad.DATA_PTR<scalar_t_0>(), labels.DATA_PTR<int64_t>(),\n          smoothing, dim_size\n      );\n    } else {\n      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n       <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n          gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),\n          max_log_sum_exp.DATA_PTR<accscalar_t>(),\n          grad.DATA_PTR<accscalar_t>(), labels.DATA_PTR<int64_t>(),\n          smoothing, dim_size\n      );\n    }\n  );\n\n  THCudaCheck(cudaGetLastError());\n  return gI;\n}\n\nstd::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){\n  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward_cuda(\n    const at::Tensor &grad_loss,\n    const at::Tensor &logits,\n    const at::Tensor &max_log_sum_exp,\n    const at::Tensor &labels,\n    const float smoothing) {\n  bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();\n  if (half_to_float) {\n     AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), \"expected input and grad types to match, or input to be at::Half and grad to be at::Float\");\n  }\n  return 
host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);\n}\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/examples/multihead_attn/func_test_multihead_attn.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--seed-start', default=1, type=int, help='Start Range of Seeds to Test')\nparser.add_argument('--seed-end', default=100, type=int, help='Stop Range of Seeds to Test')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in Python PyTorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd 
Pass.')\nparser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')\n\nargs = parser.parse_args()\nassert args.seq_length % 64 == 0, \"Sequence Length should be a multiple of 64!\"\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not supported')\ntorch.cuda.set_device(0)\n\ndropout_prob = 0.1\n\nfor seed in range(args.seed_start, args.seed_end+1) :\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    ref_layer = None\n    if args.encdec_attn :\n        ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    else :\n        ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    ref_layer.cuda()\n    ref_layer.half()\n    ref_layer.reset_parameters()\n\n    ref_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    ref_inputs_kv = None\n    if args.encdec_attn :\n        ref_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    ref_grads         = torch.randn_like(ref_inputs)\n\n    ref_outputs,_ = ref_layer.forward(ref_inputs,\n                                      ref_inputs_kv,\n                                      ref_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    ref_outputs.backward(ref_grads)\n\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    tst_layer = None\n    if 
args.encdec_attn :\n        tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    else:\n        tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    tst_layer.cuda()\n    tst_layer.half()\n    tst_layer.reset_parameters()\n\n    tst_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    tst_inputs_kv = None\n    if args.encdec_attn :\n        tst_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    assert torch.equal(ref_inputs,tst_inputs), \"ERROR: Inputs are different!\"\n\n    tst_grads         = torch.randn_like(tst_inputs)\n\n    tst_outputs,_ = tst_layer.forward(tst_inputs,\n                                      tst_inputs_kv,\n                                      tst_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    tst_outputs.backward(tst_grads)\n\n    fwd_close = torch.equal(ref_outputs, tst_outputs)\n    bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)\n\n    diff_fwd = ref_outputs - tst_outputs\n    diff_cnt_fwd = diff_fwd.ne(0.0).sum()\n    diff_accum_fwd = diff_fwd.abs().sum()\n\n    diff_bwd = ref_inputs.grad - tst_inputs.grad\n    diff_cnt_bwd = diff_bwd.ne(0.0).sum()\n    diff_accum_bwd = diff_bwd.abs().sum()\n\n    print(\">>> Seed: \", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in Python PyTorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')\nparser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')\n\nargs = parser.parse_args()\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not 
supported')\ntorch.cuda.set_device(0)\n\ntorch.manual_seed(111)\nif torch.cuda.is_available():\n    torch.cuda.manual_seed_all(111)\n\nattn_layers = []\nfor idx in range(0, args.layers) :\n    if args.encdec_attn :\n        if args.ref :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))\n        else :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    else :\n        if args.native :\n            attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))\n        elif args.ref :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))\n        else :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    attn_layers[idx].cuda()\n    attn_layers[idx].half()\n    if not args.native :\n        attn_layers[idx].reset_parameters()\n\nstart_evt_fwd = []\nstart_evt_bwd = []\nstop_evt_bwd  = []\nfor recorded_trial in range(0, args.trials) :\n    start_evt_fwd.append(torch.cuda.Event(enable_timing=True))\n    start_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n    stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n\nfor sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) :\n    inputs        = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    grads         = torch.randn_like(inputs)\n   \n    for trial in range(0, args.trials + args.warmup_trials) :\n        layer_inputs  = inputs\n        evt_idx       = trial - args.warmup_trials\n    \n        if evt_idx >= 0 
:\n            start_evt_fwd[evt_idx].record()\n    \n        for lyr_idx in range(0, args.layers) :\n            if args.native :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs, \n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None)\n            else :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs,\n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None,\n                                                         is_training=True)\n            layer_inputs = outputs\n    \n        if evt_idx >= 0 :\n            start_evt_bwd[evt_idx].record()\n\n        if not args.fwd :\n            layer_inputs.backward(grads)\n    \n        if evt_idx >= 0 :\n            stop_evt_bwd[evt_idx].record()\n   \n    torch.cuda.synchronize()\n    elapsed_time_fwd = 0.0\n    elapsed_time_bwd = 0.0\n    for evt_idx in range(0, args.trials) :\n        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])\n        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])\n   \n    print(\"[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms\".format(\n        'Encdec' if args.encdec_attn else 'Self',              \\\n        'Norm&Add' if args.norm_add else '',                   
\\\n        sequences*args.seq_length,                             \\\n        sequences,                                             \\\n        args.seq_length,                                       \\\n        elapsed_time_fwd / ( args.trials * args.layers ),      \\\n        elapsed_time_bwd / ( args.trials * args.layers )))\n\n"
  },
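The timing loop above records CUDA events only after `trial` passes `args.warmup_trials` (the `evt_idx >= 0` check) and calls `torch.cuda.synchronize()` before reading elapsed times, because GPU kernel launches are asynchronous. A minimal CPU analogue of the same warmup-then-measure pattern, using `time.perf_counter` in place of `torch.cuda.Event` (the `benchmark` helper is illustrative, not part of the script):

```python
import time

def benchmark(fn, trials=10, warmup_trials=3):
    # First `warmup_trials` iterations run but are not recorded
    # (the script's evt_idx < 0 case); the remaining `trials` are averaged.
    times_ms = []
    for trial in range(trials + warmup_trials):
        start = time.perf_counter()
        fn()
        stop = time.perf_counter()
        if trial >= warmup_trials:  # equivalent to evt_idx >= 0
            times_ms.append((stop - start) * 1e3)
    return sum(times_ms) / len(times_ms)

avg_ms = benchmark(lambda: sum(range(10000)))
```

On the GPU the start/stop pairs must be `torch.cuda.Event(enable_timing=True)` objects, since a host-side clock would only measure launch overhead, not kernel execution.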
  {
    "path": "KoSimCSE/apex/contrib/fmha/__init__.py",
    "content": "from .fmha import FMHAFun\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/fmha/fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n# \n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n# \n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\n\nimport torch\nimport torch.nn.functional as F\nimport fmhalib as mha\n\nclass FMHAFun(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):\n        batch_size = cu_seqlens.numel() - 1\n        if batch_size < 4:\n            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)\n        else:\n            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)\n        ctx.save_for_backward(qkv, S_dmask)\n        ctx.cu_seqlens = cu_seqlens\n        ctx.p_dropout = p_dropout\n        ctx.max_s = max_s\n        return context\n    \n    @staticmethod\n    def backward(ctx, dout):\n        qkv, S_dmask = ctx.saved_tensors\n        batch_size = ctx.cu_seqlens.numel() - 1\n        if batch_size < 4:\n            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)\n        else:\n            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)\n\n        return dqkv, None, None, None, None, None, None\n\nclass FMHA(torch.nn.Module):\n\n    def __init__(self, config):\n\n        super(FMHA, self).__init__()\n\n        self.p_dropout = config.attention_probs_dropout_prob\n        self.h = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.d = self.hidden_size // 
self.h\n        assert self.d * self.h == self.hidden_size, \"Invalid hidden size/num_heads\"\n\n    def forward(self, qkv, cu_seqlens, max_s, is_training=True):\n\n        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)\n\n        return ctx.view(-1, self.hidden_size)\n"
  },
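`FMHAFun` consumes a packed qkv tensor plus `cu_seqlens`, the cumulative sequence-length offsets that delimit each variable-length sequence in the batch (note `batch_size = cu_seqlens.numel() - 1` in both passes). A small NumPy sketch of how such offsets are conventionally built; `make_cu_seqlens` is an illustrative helper, not part of this module:

```python
import numpy as np

def make_cu_seqlens(seqlens):
    # cu_seqlens[i] is the row where sequence i starts in the packed
    # [total_tokens, 3, heads, head_dim] qkv tensor; cu_seqlens[-1] is the
    # total token count, so batch_size == len(cu_seqlens) - 1, matching the
    # check in FMHAFun.forward.
    cu = np.zeros(len(seqlens) + 1, dtype=np.int32)
    cu[1:] = np.cumsum(seqlens)
    return cu

cu_seqlens = make_cu_seqlens([5, 3, 7])   # -> [0, 5, 8, 15]
max_s = max([5, 3, 7])                    # the max_s argument to FMHAFun
```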
  {
    "path": "KoSimCSE/apex/contrib/groupbn/__init__.py",
    "content": "try:\n    import torch\n    import bnp\n    from .batch_norm import BatchNorm2d_NHWC\n    del torch\n    del bnp\n    del batch_norm\nexcept ImportError as err:\n    print(\"apex was installed without --bnp flag, contrib.groupbn is not available\")\n"
  },

  {
    "path": "KoSimCSE/apex/contrib/groupbn/batch_norm.py",
    "content": "import torch\nimport numpy as np\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nimport bnp\n\nclass bn_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):\n        if is_train:\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.fuse_relu = fuse_relu\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res =  bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)\n            return res\n        else:\n            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)\n\n    @staticmethod\n    def backward(ctx, grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        fuse_relu = ctx.fuse_relu\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, 
epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)\n\n        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass bn_addrelu_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):\n        if is_train:\n            bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res =  bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)\n            return res\n        else:\n            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)\n\n    @staticmethod\n    def backward(ctx, grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        
bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)\n\n        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass BatchNorm2d_NHWC(_BatchNorm):\n    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True\n    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):\n        super(BatchNorm2d_NHWC, self).__init__(num_features)\n\n        self.fuse_relu = fuse_relu\n        self.multi_stream = multi_stream\n\n        self.minibatch_mean = torch.cuda.FloatTensor(num_features)\n        self.minibatch_riv = torch.cuda.FloatTensor(num_features)\n\n        #default to distributed bn disabled\n        self.bn_group = bn_group\n        self.max_cta_per_sm = max_cta_per_sm        #used only in training fwd and bwd\n        self.cta_launch_margin = cta_launch_margin  #used only in training fwd and bwd\n        self.my_data = None\n        self.pair_data = None\n        self.pair_data2 = None\n        self.pair_data3 = None\n        self.local_rank = 0\n        self.magic = torch.IntTensor([0])\n\n        #calculate cta per sm occupancies\n        assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)\n        self.fwd_occupancy =  min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.bwd_occupancy =  min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_fwd_occupancy =  min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_bwd_occupancy =  min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)\n\n        #calculate grid 
dimensions based on occupancy numbers\n        mp_count = torch.cuda.get_device_properties(None).multi_processor_count\n        self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1)\n        self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1)\n        self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1)\n        self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1)\n        self.grid_dim_y = (num_features + 63) // 64\n\n        # allocate scratch space used by implementation\n        # TODO: scratch space that is not supposed to be exposed to user code. We only need one-time initialization; the\n        # same buffer could be reused in future iterations. Currently we expose it here instead of requesting a new\n        # buffer from the cache allocator to avoid unnecessary initialization in future iterations.\n        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)\n\n        #FIXME: turn pair handles into an array\n        if bn_group>1:\n            local_rank = torch.distributed.get_rank()\n            world_size = torch.distributed.get_world_size()\n            assert(world_size >= bn_group)\n            assert(world_size % bn_group == 0)\n\n            bn_sync_steps = 1\n            if (bn_group==4):\n                bn_sync_steps = 2\n            if (bn_group==8):\n                bn_sync_steps = 3\n\n            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))\n            self.my_data = bnp.get_data_ptr(self.ipc_buffer)\n            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`\n            self.storage = self.ipc_buffer.storage()\n            self.share_cuda = self.storage._share_cuda_()\n            internal_cuda_mem = self.share_cuda\n            # internal_cuda_mem[1]: ipc_mem_handle\n            my_handle = 
torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))\n            # internal_cuda_mem[3]: offset\n            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])\n\n            handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)\n            handles_l = list(handles_all.unbind(0))\n            torch.distributed.all_gather(handles_l, my_handle)\n\n            offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)\n            offsets_l = list(offsets_all.unbind(0))\n            torch.distributed.all_gather(offsets_l, my_offset)\n\n            #whom do I actually care about? that would be local_rank XOR 1\n            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()\n            pair_offset = offsets_l[local_rank ^ 1].cpu()\n            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)\n\n            if bn_group>2:\n                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()\n                pair_offset2 = offsets_l[local_rank ^ 2].cpu()\n                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)\n\n            if bn_group>4:\n                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()\n                pair_offset3 = offsets_l[local_rank ^ 4].cpu()\n                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)\n\n            #FIXME: get magic value into C code and eliminate from here\n            self.magic = torch.IntTensor([2])\n            self.local_rank = local_rank\n\n\n    def forward(self, x, z=None):\n        if z is not None:\n            assert(self.fuse_relu==True)\n            return bn_addrelu_NHWC_impl.apply(x, z,\n                                  self.weight, self.bias,\n                                  self.running_mean, self.running_var,\n                                  
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,\n                                  self.momentum,\n                                  self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,\n                                  self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,\n                                  self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,\n                                  self.multi_stream)\n        else:\n            return bn_NHWC_impl.apply(x,\n                                  self.weight, self.bias,\n                                  self.running_mean, self.running_var,\n                                  self.minibatch_mean, self.minibatch_riv, self.ret_cta,\n                                  self.momentum,\n                                  self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,\n                                  self.fwd_occupancy, self.fwd_grid_dim_x,\n                                  self.bwd_occupancy, self.bwd_grid_dim_x,\n                                  self.multi_stream)\n\n    def __del__(self):\n        if self.bn_group>1:\n          bnp.close_remote_data(self.pair_handle)\n          if self.bn_group>2:\n              bnp.close_remote_data(self.pair_handle2)\n              if self.bn_group>4:\n                 bnp.close_remote_data(self.pair_handle3)\n"
  },
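`bnp.bn_fwd_nhwc` computes training-mode batch norm over an NHWC tensor, reducing statistics over the N, H, and W axes per channel (optionally fusing the ReLU, as in the `bn_addrelu` variants). A NumPy reference of the same math for checking semantics; `bn_nhwc_reference` is an illustrative unfused sketch, not the CUDA kernel:

```python
import numpy as np

def bn_nhwc_reference(x, gamma, beta, eps=1e-5, fuse_relu=False):
    # x: [N, H, W, C]; per-channel statistics reduced over N, H, W,
    # matching the minibatch mean / inverse-variance the kernel tracks.
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    y = (x - mean) / np.sqrt(var + eps) * gamma + beta
    # the addrelu variants fuse the activation into the same pass
    return np.maximum(y, 0.0) if fuse_relu else y

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 4, 8))
y = bn_nhwc_reference(x, np.ones(8), np.zeros(8))
```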
  {
    "path": "KoSimCSE/apex/contrib/layer_norm/__init__.py",
    "content": "from .layer_norm import FastLayerNorm\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/layer_norm/layer_norm.py",
    "content": "import torch\nfrom torch.nn import init\n\nimport fast_layer_norm\n\nclass FastLayerNormFN(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, gamma, beta, epsilon):\n        x = x.contiguous()\n        gamma = gamma.contiguous()\n        beta = beta.contiguous()\n        hidden_size = gamma.numel()\n        xmat = x.view((-1, hidden_size))\n        ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)\n        ctx.save_for_backward(x, gamma, mu, rsigma)\n        return ymat.view(x.shape)\n    \n    @staticmethod\n    def backward(ctx, dy):\n        #assert dy.is_contiguous()\n        dy = dy.contiguous() # this happens!\n        x, gamma, mu, rsigma = ctx.saved_tensors\n\n        hidden_size = gamma.numel()\n        xmat = x.view((-1, hidden_size))\n        dymat = dy.view(xmat.shape)\n        dxmat, dgamma, dbeta = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)\n        dx = dxmat.view(x.shape)\n        return dx, dgamma, dbeta, None\n\nclass FastLayerNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-5):\n        super(FastLayerNorm, self).__init__()\n        self.epsilon = eps\n        self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))\n        self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.ones_(self.weight)\n        init.zeros_(self.bias)\n\n    def forward(self, x):\n        return FastLayerNormFN.apply(x, self.weight, self.bias, self.epsilon)\n"
  },
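`fast_layer_norm.ln_fwd` returns the normalized output along with per-row `mu` and `rsigma` (reciprocal standard deviation), which the backward pass reuses. A NumPy reference of the same computation, mirroring the `x.view((-1, hidden_size))` flattening; this is an unfused sketch for checking shapes and semantics, not the CUDA kernel:

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, eps=1e-5):
    # Flatten to [rows, hidden] exactly as FastLayerNormFN does,
    # then normalize each row over the hidden dimension.
    hidden = gamma.size
    xmat = x.reshape(-1, hidden)
    mu = xmat.mean(axis=1, keepdims=True)
    rsigma = 1.0 / np.sqrt(xmat.var(axis=1, keepdims=True) + eps)
    ymat = (xmat - mu) * rsigma * gamma + beta
    return ymat.reshape(x.shape), mu, rsigma
```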
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/README.md",
    "content": "# Fast Multihead Attention \n\nThis implementation has two main features :\n* A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes.\n* The removal of all copies and transposes found in standard implementations of Multihead Attention.\n\n|                                            | Python Version | C++ Version |\n| :----------------------------------------- | :------------: | :---------: |\n| Layer Norm and Residual Add Variant        | X              | X           |\n| Includes Linear Biases                     | X              |             |\n| Reduces CPU Overheads                      |                | X           |\n| Fuses masking with Softmax                 |                | X           |\n| Removes Transposes and Copies              | X              | X           |\n| Includes Self and Encoder/Decoder Variants | X              | X           |\n\n## How to Instantiate\n\n`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n\n `impl` has two options:\n * `fast` uses C++ Version\n * `default` uses Python Version\n\n## Instructions to build on Linux\n\n```\n$ git clone https://github.com/NVIDIA/apex\n$ cd apex\n$ pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--fast_multihead_attn\" ./\n```\n## Try Performance Tests Yourself!\nPerf test script is found here!\n```\ncd contrib/examples/multihead_attn\n```\n#### Fast Multihead Attention\n```\npython perf_test_multihead_attn.py --ref\n```\n#### Fast Multihead Attention with C++ Implementation\n```\npython perf_test_multihead_attn.py\n```\n#### Compare with `torch.nn.MultiheadAttn`\n```\npython perf_test_multihead_attn.py --native\n```\n#### Test your own range!\n```\npython perf_test_multihead_attn.py --seq-length 64 
--num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5\n```\n\n## Performance Comparisons\n\n* Performance was measured with 64-token sequence lengths on an NVIDIA Titan V card.\n* Time is measured across multiple layers to simulate an in-model scenario.\n\n![Multihead Attention Forward](MHA_fwd.png)\n![Multihead Attention Backward](MHA_bwd.png)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/__init__.py",
    "content": "from .self_multihead_attn import SelfMultiheadAttn\nfrom .encdec_multihead_attn import EncdecMultiheadAttn\nfrom .mask_softmax_dropout_func import fast_mask_softmax_dropout_func\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/encdec_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .encdec_multihead_attn_func               import encdec_attn_func\nfrom .fast_encdec_multihead_attn_func          import fast_encdec_attn_func\nfrom .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm       import FusedLayerNorm\n\nif hasattr(torch._C, '_jit_set_profiling_executor') :\n    torch._C._jit_set_profiling_executor(False)\nif hasattr(torch._C, '_jit_set_profiling_mode') :\n    torch._C._jit_set_profiling_mode(False)\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass EncdecMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n\n        self.in_proj_weight_q    = Parameter(torch.Tensor(embed_dim, embed_dim))\n        self.in_proj_weight_kv   = Parameter(torch.Tensor(2*embed_dim, embed_dim))\n        self.out_proj_weight     = Parameter(torch.Tensor(embed_dim, embed_dim))\n        if self.bias:\n            assert impl != 'fast', \"ERROR! 
The Fast implementation does not support biases!\"\n            self.in_proj_bias_q  = Parameter(torch.Tensor(embed_dim))\n            self.in_proj_bias_kv = Parameter(torch.Tensor(2*embed_dim))\n            self.out_proj_bias   = Parameter(torch.Tensor(embed_dim))\n        else:\n            self.register_parameter('in_proj_bias_q', None)\n            self.register_parameter('in_proj_bias_kv', None)\n            self.in_proj_bias_q  = None\n            self.in_proj_bias_kv = None\n            self.out_proj_bias   = None\n        if self.include_norm_add:\n            if impl == 'fast':\n                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm_beta_weights  = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm               = None\n            else:\n                self.register_parameter('lyr_norm_gamma_weights', None)\n                self.register_parameter('lyr_norm_beta_weights', None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights  = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        if self.include_norm_add:\n            if   impl == 'fast'    : self.attn_func = fast_encdec_attn_norm_add_func\n            elif impl == 'default' : self.attn_func = encdec_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if   impl == 'fast'    : self.attn_func = fast_encdec_attn_func\n            elif impl == 'default' : self.attn_func = encdec_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        nn.init.xavier_uniform_(self.in_proj_weight_q)\n        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be\n        # initialized like a [hidden, hidden] matrix.\n        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + 
hidden)) = sqrt(1.5)\n        # therefore xavier_uniform gain should be set to sqrt(1.5).\n        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            nn.init.constant_(self.in_proj_bias_q, 0.)\n            nn.init.constant_(self.in_proj_bias_kv, 0.)\n            nn.init.constant_(self.out_proj_bias, 0.)\n        if self.include_norm_add:\n            if self.impl == 'fast' :\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):\n        \"\"\"Input shape: Time x Batch x Channel\n\n        Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `mask_future_timesteps` argument. 
Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n\n        if key_padding_mask is not None:\n            assert (attn_mask is None), \"ERROR attn_mask and key_padding_mask should not be both defined!\"\n            mask = key_padding_mask\n        elif attn_mask is not None:\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,\n                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)\n            else:\n                lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,\n                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,\n                                         mask, self.dropout)\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n        else:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)\n            else:\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, 
self.scaling, query, key,\n                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,\n                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,\n                                         mask, self.dropout)\n\n        return outputs, None\n"
  },
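The attention function behind `EncdecMultiheadAttn` operates on tensors laid out as [seq_len, batch*heads, head_dim], performing two batched GEMMs (scores, then the weighted sum of values) with a softmax in between. A NumPy reference of that core computation, ignoring the projections, masking, and dropout; `attention_reference` is an illustrative sketch, not the apex implementation:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_reference(q, k, v, scale):
    # q: [seql_q, batch*heads, head_dim]; k, v: [seql_k, batch*heads, head_dim]
    # Matmul1 per batch*heads: (seql_q x head_dim) x (head_dim x seql_k)
    scores = np.einsum('qbd,kbd->bqk', q, k) * scale
    probs = softmax(scores)
    # Matmul2 per batch*heads: (seql_q x seql_k) x (seql_k x head_dim)
    return np.einsum('bqk,kbd->qbd', probs, v)
```

The `scale` argument corresponds to the module's `self.scaling = head_dim ** -0.5`; the fused C++ path folds it into the first GEMM's alpha parameter rather than a separate elementwise multiply.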
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/encdec_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass EncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,\n                input_weights_q, input_weights_kv, output_weights,\n                input_biases_q, input_biases_kv, output_biases,\n                mask, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases_q is not None])\n        heads_t        = torch.tensor([heads])\n        scale_t        = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        head_dim       = inputs_q.size(2) // heads\n\n        # Input Linear GEMM Q\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim (1024), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        if use_biases_t[0]:\n            input_lin_q_results = torch.addmm(input_biases_q,\n                                              inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                                              input_weights_q.transpose(0,1),\n                                              beta=1., alpha=1.)\n        else:\n            input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))\n        input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))\n        # Input Linear GEMM KV\n        # input1: (activations) [seql_k, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_k, seqs, embed_dim*2]\n        # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x 
embed_dim*2 ) = (seql_k*seqs x embed_dim*2)\n        if use_biases_t[0]:\n            input_lin_kv_results = torch.addmm(input_biases_kv,\n                                               inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),\n                                               input_weights_kv.transpose(0,1),\n                                               beta=1., alpha=1.)\n        else:\n            input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))\n\n        # Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads, head_dim)\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)\n        keys    = input_lin_kv_results[:,:,0,:]\n        values  = input_lin_kv_results[:,:,1,:]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of \n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = 
torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))\n        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert (len(mask.size()) == 2), \"Timing mask is not 2D!\"\n                assert (mask.size(0) == mask.size(1)), \"Sequence length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))\n            # Key Padding Mask\n            else:\n                batches,seql_q,seql_k = matmul1_results.size()\n                seqs = int(batches / heads)\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))\n                matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout - is not executed for inference\n        if is_training:\n            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))\n        else:\n            dropout_results = softmax_results\n            dropout_mask    = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Given that pytorch cannot currently perform autograd with an output tensor specified,\n        # this requires a backward pass specified.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] 
transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)\n        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(output_biases,\n                                  matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                                  output_weights.transpose(0,1),\n                                  beta=1., alpha=1.)\n        else:\n            outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))\n        outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(use_biases_t,                             \\\n                              heads_t,                                  \\\n                              scale_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              
input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n    \n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                   \\\n        heads_t,                                                        \\\n        scale_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        head_dim                = inputs_q.size(2) // heads_t[0]\n\n        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)\n        # Sequences and heads are combined to make the batch of 
the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads_t[0], head_dim)\n        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)\n        keys    = input_lin_kv_results[:,:,0,:]\n        values  = input_lin_kv_results[:,:,1,:]\n\n        # Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared beforehand to properly slice out query, key, and value grads.\n        input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)\n        queries_grads              = torch.empty_like(queries)\n        keys_grads                 = input_lin_kv_results_grads[:,:,0,:]\n        values_grads               = input_lin_kv_results_grads[:,:,1,:]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))\n        # Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ embed_dim, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = 
( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),\n                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads   = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  
[seqs*heads, seql_q, seql_k] \n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),\n                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads    = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),\n                                      out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n\n        # Input Q Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)    [embed_dim (1024), embed_dim (1024)] \n        # output:              [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        queries_grads  = queries_grads.transpose(0,1).view(inputs_q.size(0)*inputs_q.size(1), heads_t[0]*head_dim)\n        input_q_grads = torch.mm(queries_grads, input_weights_q)\n        input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n        # Input KV Linear GEMM - DGRAD\n        # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]\n        # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)] \n        # output:              [seql_k, seqs, embed_dim]\n        # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = 
(seql_k*seqs x embed_dim)\n        input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0)*inputs_kv.size(1), heads_t[0]*2*head_dim)\n        input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)\n        input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))\n        # Input Q Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)] \n        # output:               [embed_dim, embed_dim]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)\n        input_weight_q_grads = torch.mm(queries_grads.transpose(0,1), inputs_q.view(inputs_q.size(0)*inputs_q.size(1), inputs_q.size(2)))\n        # Input KV Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]\n        # input2: (activations) [seql_k*seqs, embed_dim(1024)] \n        # output:               [2*embed_dim, embed_dim]\n        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)\n        input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))\n\n        if use_biases_t[0]:\n            input_bias_grads_q = torch.sum(queries_grads, 0)\n            input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)\n        else:\n            input_bias_grads_q = None\n            input_bias_grads_kv = None\n\n        return None, None, None, None,                                            \\\n               input_q_grads, input_kv_grads,                                     \\\n               input_weight_q_grads, input_weight_kv_grads, output_weight_grads,  \\\n               input_bias_grads_q, input_bias_grads_kv, output_bias_grads,        \\\n               None, None\n\nencdec_attn_func = EncdecAttnFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py",
    "content": "import torch\nimport fast_encdec_multihead_attn\n\n\nclass FastEncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        outputs =                                                       \\\n            fast_encdec_multihead_attn.forward(                         \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n                              heads,                                    \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                  
            matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        input_q_grads,                                                  \\\n  
      input_kv_grads,                                                 \\\n        input_weight_q_grads,                                           \\\n        input_weight_kv_grads,                                          \\\n        output_weight_grads =                                           \\\n            fast_encdec_multihead_attn.backward(                        \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t[0])\n\n        return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None\n\nfast_encdec_attn_func = FastEncdecAttnFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py",
    "content": "# Copyright (c) 2017-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the LICENSE file in\n# the root directory of this source tree. An additional grant of patent rights\n# can be found in the PATENTS file in the same directory.\n\nimport torch\nimport fast_encdec_multihead_attn_norm_add\n\n\nclass FastEncdecAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        dropout_add_mask,                                               \\\n        outputs =                                                       \\\n            fast_encdec_multihead_attn_norm_add.forward(                \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n               
               heads,                                    \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                         
    \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_q_results,                                            \\\n        input_lin_kv_results,                                           \\\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        inputs_q,                                                       \\\n        inputs_kv,                                                      \\\n        lyr_nrm_gamma_weights,                                          \\\n        lyr_nrm_beta_weights,                                           \\\n        input_weights_q,                                                \\\n        input_weights_kv,                                               \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_add_mask,                                               \\\n        dropout_prob_t         = ctx.saved_tensors\n\n        input_q_grads,                                                  \\\n        input_kv_grads,                                                 \\\n        lyr_nrm_gamma_grads,                                            \\\n        lyr_nrm_beta_grads,                                             \\\n        input_weight_q_grads,                    
                       \\\n        input_weight_kv_grads,                                          \\\n        output_weight_grads    =                                        \\\n            fast_encdec_multihead_attn_norm_add.backward(               \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_q_results,                      \\\n                              input_lin_kv_results,                     \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs_q,                                 \\\n                              inputs_kv,                                \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights_q,                          \\\n                              input_weights_kv,                         \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t[0])\n\n        #import pdb; pdb.set_trace()\n        return None, None, None,                                        \\\n               input_q_grads,                                           \\\n               input_kv_grads,                             
             \\\n               lyr_nrm_gamma_grads,                                     \\\n               lyr_nrm_beta_grads,                                      \\\n               input_weight_q_grads,                                    \\\n               input_weight_kv_grads,                                   \\\n               output_weight_grads,                                     \\\n               None, None\n\nfast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py",
    "content": "import torch\nimport fast_self_multihead_attn\nimport fast_self_multihead_attn_bias\nimport fast_self_multihead_attn_bias_additive_mask\n\nclass FastSelfAttnFunc(torch.autograd.Function) :\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases is not None])\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n        mask_additive_t= torch.tensor([mask_additive])\n\n        if use_biases_t[0]:\n            if not mask_additive:\n                input_lin_results,                                              \\\n                softmax_results,                                                \\\n                dropout_results,                                                \\\n                dropout_mask,                                                   \\\n                matmul2_results,                                                \\\n                outputs =                                                       \\\n                    fast_self_multihead_attn_bias.forward(                           \\\n                                      use_mask,                                 \\\n                                      use_time_mask,                            \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      
input_biases,                           \\\n                                      output_biases,                           \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n                ctx.save_for_backward(use_biases_t,                                  \\\n                              heads_t,                          \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              null_tensor,                          \\\n                              null_tensor,                          \\\n                              mask_additive_t,                          \\\n                              input_lin_results,                        \\\n                              inputs,                                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n            else:\n                input_lin_results,                                              \\\n                bmm1_results,                                                \\\n                dropout_results,                                                \\\n                dropout_mask,                                                   \\\n                matmul2_results,                                                \\\n                outputs =                                                       \\\n                    fast_self_multihead_attn_bias_additive_mask.forward(                           \\\n                                      use_mask,                                
 \\\n                                      use_time_mask,                            \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      input_biases,                           \\\n                                      output_biases,                           \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n                ctx.save_for_backward(use_biases_t,                                  \\\n                                      heads_t,                          \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      null_tensor,                          \\\n                                      bmm1_results,                          \\\n                                      pad_mask,                          \\\n                                      mask_additive_t,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t)\n\n\n        else:\n            input_lin_results,          
                                    \\\n            softmax_results,                                                \\\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            matmul2_results,                                                \\\n            outputs =                                                       \\\n                fast_self_multihead_attn.forward(                           \\\n                                  use_mask,                                 \\\n                                  use_time_mask,                            \\\n                                  is_training,                              \\\n                                  heads,                                    \\\n                                  inputs,                                   \\\n                                  input_weights,                            \\\n                                  output_weights,                           \\\n                                  pad_mask if use_mask else null_tensor,    \\\n                                  dropout_prob)\n            ctx.save_for_backward(use_biases_t,                                  \\\n                          heads_t,                          \\\n                          matmul2_results,                          \\\n                          dropout_results,                          \\\n                          softmax_results,                          \\\n                          null_tensor,                          \\\n                          null_tensor,                          \\\n                          mask_additive_t,                          \\\n                          input_lin_results,                        \\\n                          inputs,                                   \\\n                          input_weights,                            \\\n  
                        output_weights,                           \\\n                          dropout_mask,                             \\\n                          dropout_prob_t)\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                        \\\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        bmm1_results,                                                \\\n        pad_mask,                                                \\\n        mask_additive_t,                                                \\\n        input_lin_results,                                              \\\n        inputs,                                                         \\\n        input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        if use_biases_t[0]:\n            if not mask_additive_t[0]:\n                input_grads,                                                    \\\n                input_weight_grads,                                             \\\n                output_weight_grads,                                           \\\n                input_bias_grads,                                                   \\\n                output_bias_grads =                                                    \\\n                    fast_self_multihead_attn_bias.backward(                          \\\n                                      heads_t[0],                               \\\n                 
                     output_grads,                             \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      softmax_results,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                                   \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t[0])\n\n            else:\n                input_grads,                                                    \\\n                input_weight_grads,                                             \\\n                output_weight_grads,                                           \\\n                input_bias_grads,                                                   \\\n                output_bias_grads =                                                    \\\n                    fast_self_multihead_attn_bias_additive_mask.backward(                          \\\n                                      heads_t[0],                               \\\n                                      output_grads,                             \\\n                                      matmul2_results,                          \\\n                                      dropout_results,                          \\\n                                      bmm1_results,                          \\\n                                      pad_mask,                          \\\n                                      input_lin_results,                        \\\n                                      inputs,                          
         \\\n                                      input_weights,                            \\\n                                      output_weights,                           \\\n                                      dropout_mask,                             \\\n                                      dropout_prob_t[0])\n                    \n        else:\n            input_bias_grads = None                                                    \n            output_bias_grads = None\n            input_grads,                                                    \\\n            input_weight_grads,                                             \\\n            output_weight_grads =                                           \\\n                fast_self_multihead_attn.backward(                          \\\n                                  heads_t[0],                               \\\n                                  output_grads,                             \\\n                                  matmul2_results,                          \\\n                                  dropout_results,                          \\\n                                  softmax_results,                          \\\n                                  input_lin_results,                        \\\n                                  inputs,                                   \\\n                                  input_weights,                            \\\n                                  output_weights,                           \\\n                                  dropout_mask,                             \\\n                                  dropout_prob_t[0])\n        return None, None, None, input_grads, input_weight_grads, output_weight_grads,input_bias_grads, output_bias_grads, None, None, None\n\nfast_self_attn_func = FastSelfAttnFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py",
    "content": "import torch\nimport fast_self_multihead_attn_norm_add\n\n\nclass FastSelfAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        lyr_nrm_invvar,                                                 \\\n        input_lin_results,                                              \\\n        softmax_results,                                                \\\n        dropout_results,                                                \\\n        dropout_mask,                                                   \\\n        matmul2_results,                                                \\\n        dropout_add_mask,                                               \\\n        outputs =                                                       \\\n             fast_self_multihead_attn_norm_add.forward(                 \\\n                              use_mask,                                 \\\n                              use_time_mask,                            \\\n                              is_training,                              \\\n                              heads,                                    \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights,                            \\\n                              
output_weights,                           \\\n                              pad_mask if use_mask else null_tensor,    \\\n                              dropout_prob)\n\n        ctx.save_for_backward(heads_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,                     \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        heads_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_results,                                              \\\n        lyr_nrm_results,                                                \\\n        lyr_nrm_mean,                                                   \\\n        
lyr_nrm_invvar,                                                 \\\n        inputs,                                                         \\\n        lyr_nrm_gamma_weights,                                          \\\n        lyr_nrm_beta_weights,                                           \\\n        input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_add_mask,                                               \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        input_grads,                                                    \\\n        lyr_nrm_gamma_grads,                                            \\\n        lyr_nrm_beta_grads,                                             \\\n        input_weight_grads,                                             \\\n        output_weight_grads    =                                        \\\n            fast_self_multihead_attn_norm_add.backward(                 \\\n                              heads_t[0],                               \\\n                              output_grads,                             \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              lyr_nrm_results,                          \\\n                              lyr_nrm_mean,                             \\\n                              lyr_nrm_invvar,                           \\\n                              inputs,                                   \\\n                              lyr_nrm_gamma_weights,                    \\\n                              lyr_nrm_beta_weights,  
                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_add_mask,                         \\\n                              dropout_prob_t[0])\n\n        return None, None, None,                                        \\\n               input_grads,                                             \\\n               lyr_nrm_gamma_grads,                                     \\\n               lyr_nrm_beta_grads,                                      \\\n               input_weight_grads,                                      \\\n               output_weight_grads,                                     \\\n               None, None\n\nfast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/mask_softmax_dropout_func.py",
    "content": "import torch\nimport fast_mask_softmax_dropout\nimport fast_additive_mask_softmax_dropout\n\n\nclass MaskSoftmaxDropout(torch.autograd.Function) :\n    @staticmethod\n    def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):\n        heads_t        = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        use_mask       = (pad_mask is not None)\n        use_mask_t     = torch.tensor([use_mask])\n        mask_additive_t     = torch.tensor([mask_additive])\n\n        if mask_additive:\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            softmax_results =                                                \\\n                    fast_additive_mask_softmax_dropout.forward(                           \\\n                                      use_mask,                                 \\\n                                      is_training,                              \\\n                                      heads,                                    \\\n                                      inputs,                                   \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n        else:\n            dropout_results,                                                \\\n            dropout_mask,                                                   \\\n            softmax_results =                                                \\\n                    fast_mask_softmax_dropout.forward(                           \\\n                                      use_mask,                                 \\\n                                      is_training,                              \\\n                                      heads,                              
      \\\n                                      inputs,                                   \\\n                                      pad_mask if use_mask else null_tensor,    \\\n                                      dropout_prob)\n        \n        ctx.save_for_backward(\n                              use_mask_t,                                    \\\n                              heads_t,                                 \\\n                              softmax_results,                          \\\n                              dropout_mask,                             \\\n                              pad_mask if use_mask else null_tensor,        \\\n                              mask_additive_t,        \\\n                              dropout_prob_t)\n\n        return dropout_results.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_mask_t, \\\n        heads_t,   \\\n        softmax_results,                                                \\\n        dropout_mask,                                              \\\n        pad_mask,                                                   \\\n        mask_additive_t,                                                   \\\n        dropout_prob_t      = ctx.saved_tensors\n\n        if mask_additive_t[0]:\n            input_grads =                                                    \\\n                fast_additive_mask_softmax_dropout.backward(                          \\\n                                  use_mask_t[0],                             \\\n                                  heads_t[0],                             \\\n                                  output_grads,                             \\\n                                  softmax_results,                          \\\n                                  dropout_mask,                             \\\n                                  dropout_prob_t[0])\n        else:\n            input_grads =                                        
            \\\n                fast_mask_softmax_dropout.backward(                          \\\n                                  use_mask_t[0],                             \\\n                                  heads_t[0],                             \\\n                                  output_grads,                             \\\n                                  softmax_results,                          \\\n                                  dropout_mask,                             \\\n                                  pad_mask,                             \\\n                                  dropout_prob_t[0])\n        return None, None, input_grads, None, None, None\n\nfast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/self_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .self_multihead_attn_func               import self_attn_func\nfrom .fast_self_multihead_attn_func          import fast_self_attn_func\nfrom .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm     import FusedLayerNorm\n\nif hasattr(torch._C, '_jit_set_profiling_executor') :\n    torch._C._jit_set_profiling_executor(False)\nif hasattr(torch._C, '_jit_set_profiling_mode') :\n    torch._C._jit_set_profiling_mode(False)\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass SelfMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n        self.separate_qkv_params = separate_qkv_params\n        self.mask_additive = mask_additive\n        if mask_additive:\n            assert self.include_norm_add == False, \"additive mask not supported with layer norm\"\n            assert impl == 'default' or (impl == 'fast' and bias), \"additive mask not supported for fast mode without bias\"\n        if separate_qkv_params:\n            self.q_weight  = 
Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.k_weight  = Parameter(torch.Tensor(embed_dim, embed_dim))\n            self.v_weight  = Parameter(torch.Tensor(embed_dim, embed_dim))\n        else:\n            self.in_proj_weight  = Parameter(torch.Tensor(3*embed_dim, embed_dim))\n        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))\n        if self.bias:\n            if separate_qkv_params:\n                self.q_bias  = Parameter(torch.Tensor(embed_dim))\n                self.k_bias  = Parameter(torch.Tensor(embed_dim))\n                self.v_bias  = Parameter(torch.Tensor(embed_dim))\n            else:\n                self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))\n            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))\n        else:\n            if separate_qkv_params:\n                self.register_parameter('q_bias', None)\n                self.register_parameter('k_bias', None)\n                self.register_parameter('v_bias', None)\n                self.q_bias = None\n                self.k_bias = None\n                self.v_bias = None\n            else:\n                self.register_parameter('in_proj_bias', None)\n                self.in_proj_bias = None\n            self.register_parameter('out_proj_bias', None)\n            self.out_proj_bias = None\n        if self.include_norm_add:\n            if impl == 'fast':\n                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm_beta_weights  = Parameter(torch.Tensor(embed_dim))\n                self.lyr_nrm               = None\n            else:\n                self.register_parameter('lyr_norm_gamma_weights', None)\n                self.register_parameter('lyr_norm_beta_weights', None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights  = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        
if self.include_norm_add:\n            if   impl == 'fast'    : self.attn_func = fast_self_attn_norm_add_func\n            elif impl == 'default' : self.attn_func = self_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if   impl == 'fast'    : self.attn_func = fast_self_attn_func\n            elif impl == 'default' : self.attn_func = self_attn_func\n            else :                   assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        if self.separate_qkv_params:\n            nn.init.xavier_uniform_(self.q_weight)\n            nn.init.xavier_uniform_(self.k_weight)\n            nn.init.xavier_uniform_(self.v_weight)\n        else:\n            # in_proj_weight has shape [3 * hidden, hidden] but it should be\n            # initialized like a [hidden, hidden] matrix.\n            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)\n            # therefore xavier_uniform gain should be set to sqrt(2).\n            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            if self.separate_qkv_params:\n                nn.init.constant_(self.q_bias, 0.)\n                nn.init.constant_(self.k_bias, 0.)\n                nn.init.constant_(self.v_bias, 0.)\n            else:\n                nn.init.constant_(self.in_proj_bias, 0.)\n            nn.init.constant_(self.out_proj_bias, 0.)\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):\n        \"\"\"Input shape: Time x Batch x Channel\n\n        
Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `mask_future_timesteps` argument. Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n        if self.separate_qkv_params:\n            input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()\n        else: \n            input_weights = self.in_proj_weight\n        if self.bias:\n            if self.separate_qkv_params:\n                input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)],dim=1).reshape(3*self.embed_dim).contiguous()\n            else:\n                input_bias = self.in_proj_bias\n        else:\n            input_bias=None        \n        if key_padding_mask is not None:\n            assert (attn_mask is None), \"ERROR attn_mask and key_padding_mask should not be both defined!\"\n            mask = key_padding_mask\n        elif attn_mask is not None:\n            assert self.mask_additive == False, \"additive mask not supported for time mask\"\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,\n                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,\n                                         input_weights, self.out_proj_weight, mask, self.dropout)\n            else:\n 
               lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,\n                                         input_weights, self.out_proj_weight,\n                                         input_bias, self.out_proj_bias,\n                                         mask, self.dropout)\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n        else:\n            if self.impl == 'fast':\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,\n                                         input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)\n            else:\n                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,\n                                         input_weights, self.out_proj_weight,\n                                         input_bias, self.out_proj_bias,\n                                         mask, self.mask_additive, self.dropout)\n\n        return outputs,None\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/multihead_attn/self_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nclass SelfAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, use_time_mask, is_training, heads, scale, inputs,\n                input_weights, output_weights,\n                input_biases, output_biases,\n                mask, is_additive_mask, dropout_prob):\n        use_biases_t   = torch.tensor([input_biases is not None])\n        heads_t        = torch.tensor([heads])\n        scale_t        = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor    = torch.tensor([])\n        head_dim       = inputs.size(2) // heads\n\n        # Input Linear GEMM\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim*3]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)\n        if use_biases_t[0]:\n            input_lin_results = torch.addmm(input_biases,\n                                            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                                            input_weights.transpose(0,1),\n                                            beta=1., alpha=1.)\n        else:\n            input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))\n        input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))\n\n        # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results = 
 input_lin_results.view(inputs.size(0), inputs.size(1)*heads, 3, head_dim)\n        queries = input_lin_results[:,:,0,:]\n        keys    = input_lin_results[:,:,1,:]\n        values  = input_lin_results[:,:,2,:]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of\n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))\n        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert (len(mask.size()) == 2), \"Timing mask is not 2D!\"\n                assert (mask.size(0) == mask.size(1)), \"Sequence length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))\n            # Key Padding Mask\n            else:\n                batches,seql_q,seql_k = matmul1_results.size()\n                seqs = int(batches / heads)\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                if is_additive_mask:\n                    matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)\n                else:\n                    mask = mask.to(torch.bool)\n                    matmul1_results = 
matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))\n                matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout - not executed during inference\n        if is_training:\n            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))\n        else:\n            dropout_results = softmax_results\n            dropout_mask    = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Because PyTorch cannot currently perform autograd through an op with an explicitly\n        # specified output tensor, a custom backward pass must be provided.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)\n        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(output_biases,\n                                  matmul2_results.view(inputs.size(0) 
* inputs.size(1), inputs.size(2)),\n                                  output_weights.transpose(0,1),\n                                  beta=1., alpha=1.)\n        else:\n            outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))\n        outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(use_biases_t,                             \\\n                              heads_t,                                  \\\n                              scale_t,                                  \\\n                              matmul2_results,                          \\\n                              dropout_results,                          \\\n                              softmax_results,                          \\\n                              input_lin_results,                        \\\n                              inputs,                                   \\\n                              input_weights,                            \\\n                              output_weights,                           \\\n                              dropout_mask,                             \\\n                              dropout_prob_t)\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        use_biases_t,                                                   \\\n        heads_t,                                                        \\\n        scale_t,                                                        \\\n        matmul2_results,                                                \\\n        dropout_results,                                                \\\n        softmax_results,                                                \\\n        input_lin_results,                                              \\\n        inputs,                                                         \\\n        
input_weights,                                                  \\\n        output_weights,                                                 \\\n        dropout_mask,                                                   \\\n        dropout_prob_t          = ctx.saved_tensors\n\n        head_dim                = inputs.size(2) // heads_t[0]\n\n        # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results       = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)\n        queries                 = input_lin_results[:,:,0,:]\n        keys                    = input_lin_results[:,:,1,:]\n        values                  = input_lin_results[:,:,2,:]\n\n        # Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared beforehand to properly slice out query, key, and value grads.\n        input_lin_results_grads = torch.empty_like(input_lin_results)\n        queries_grads           = input_lin_results_grads[:,:,0,:]\n        keys_grads              = input_lin_results_grads[:,:,1,:]\n        values_grads            = input_lin_results_grads[:,:,2,:]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), 
output_grads.size(2)), output_weights)\n        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))\n        # Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ embed_dim, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),\n                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))\n        output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads   = torch.bmm(dropout_results.transpose(1,2), 
output_lin_grads, out=values_grads.transpose(0,1))\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] \n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),\n                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads    = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),\n                                      out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])\n\n        # Input Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]\n        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)] \n        # output:              [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)\n        
input_grads = torch.mm(input_lin_results_grads, input_weights)\n        input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))\n        # Input Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, 3*embed_dim(3072)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)] \n        # output:               [3*embed_dim, embed_dim]\n        # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)\n        input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))\n\n        if use_biases_t[0]:\n            input_bias_grads = torch.sum(input_lin_results_grads, 0)\n        else:\n            input_bias_grads = None\n\n        return None, None, None, None,                   \\\n               input_grads,                              \\\n               input_weight_grads, output_weight_grads,  \\\n               input_bias_grads, output_bias_grads,      \\\n               None, None\n\nself_attn_func = SelfAttnFunc.apply\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/__init__.py",
    "content": "from .fp16_optimizer import FP16_Optimizer\nfrom .fused_adam import FusedAdam\nfrom .fused_lamb import FusedLAMB\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/distributed_fused_adam.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nimport torch.distributed.distributed_c10d as c10d\n\nclass DistributedFusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n    \n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n    \n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. 
(default: False)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        step_supports_amp_scaling(boolean, optional): whether to use customized\n            gradient unscaling logic (default: True)\n        num_process_groups (integer, optional): number of process groups in\n            the app (default: 1)\n        current_process_group (object, optional): the process group to work on\n            (default: None)\n        process_group_id (integer, optional): process group id (default: 0)\n        process_group_size (integer, optional): size of process group\n            (default: 0)\n        clip_grad_norm (boolean, optional): whether to handle gradient clipping\n            (default: True)\n        model_parallel (boolean, optional): whether model parallelism is used\n            (default: False)\n\n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction=True, betas=(0.9, 0.999),\n                 eps=1e-8, eps_inside_sqrt=False,\n                 weight_decay=0., max_grad_norm=0.,\n                 amsgrad=False, flat_mt=False,\n                 overlap_reductions=True,\n                 compute_L2_grad_norm=False,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,\n                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,\n                 predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False,\n                 step_supports_amp_scaling=True,\n                 num_process_groups=1,\n                 current_process_group=None,\n                 process_group_id=0,\n                 process_group_size=0,\n                 clip_grad_norm=True,\n                 model_parallel=False):\n        global fused_adam_cuda, distributed_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n        distributed_adam_cuda = importlib.import_module(\"distributed_adam_cuda\")\n        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdam, self).__init__(params, defaults)\n\n        # Misc\n        self.eps_mode = 0 if eps_inside_sqrt else 1\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n        self._step_supports_amp_scaling = step_supports_amp_scaling\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n     
   self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._compute_L2_grad_norm = compute_L2_grad_norm\n        self._L2_grad_norm = None\n        self._flat_mt = flat_mt\n        self._init_done = False\n        self._resume_from_checkpoint = False\n        self._step = 0\n\n        # Process group related\n        self._clip_grad_norm = clip_grad_norm\n        self._model_parallel = model_parallel\n        self._num_process_groups = num_process_groups\n        self._current_process_group = current_process_group if current_process_group is not None else c10d._get_default_group()\n        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())\n        self._process_group_id = process_group_id\n        self._process_group_size = torch.cuda.device_count() if process_group_size <= 0 else process_group_size\n        self._world_size = self._process_group_size # world: the current process group\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._num_groups = self._world_size // self._group_size\n        self._global_rank = torch.distributed.get_rank()\n        self._world_rank = self._global_rank // self._num_process_groups\n        self._group_rank = self._world_rank % self._group_size\n        #print(\"world_size:\", self._world_size, \", group_size:\", self._group_size, \", num_groups:\", self._num_groups, \", global_rank:\", self._global_rank, \", world_rank:\", self._world_rank, \", group_rank:\", self._group_rank)\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n\n        # Master weight, moment, gradient buffers\n        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None\n\n  
  def _first_step_init(self):\n        p_offset = 0\n        p_i = 0\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        self._group_properties = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            beta1, beta2 = group['betas']\n            bias_correction = 1 if group['bias_correction'] else 0\n            eps = group['eps']\n            weight_decay = group['weight_decay']\n            for p in group['params']:\n                # broadcast from rank 0 of current process group\n                torch.distributed.broadcast(p, src=self._available_ranks[0], group=self._current_process_group)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                # Multiple param groups support: \n                # store one hyperparam item per parameter tensor\n                self._group_properties.append((\n                    beta1,\n                    beta2,\n                    bias_correction,\n                    eps,\n                    weight_decay\n                    ))\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n              
  # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._grads = []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._chunk_size = self._block_size // self._num_chunks\n        self._shard_size = self._chunk_size // self._group_size\n        #print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        #print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * 
self._shard_size\n        # initialize master weights, moments buffers if not loaded from checkpoint\n        if self._fp32_p is None:\n            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        self._individual_flat_grads = []\n        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):\n            self._individual_flat_grads.append(self._flat_grads[grads_info[\"param_offset\"]:grads_info[\"param_offset\"]+grads_info[\"param_grads_size\"]].view_as(p))\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            list_of_blocks = __blockify(self._flat_grads)\n            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]\n            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards\n        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = 
_flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._shard_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, 
self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        # This paragraph does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params = []\n        self._contrib_tensor_list = []\n        self._contrib_group_properties = []\n        self._non_parallel_grads = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size\n                    flat_shard_end = flat_shard_start + self._shard_size\n                    for (p, grads_info, group_props) in zip(self._model_params, self._grads_info, self._group_properties):\n                        flat_grad_start = grads_info[\"param_offset\"]\n                        flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                        clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)\n                        clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)\n                        if clipped_start < clipped_end:\n                            grad_offset = clipped_start - flat_grad_start\n                            grad_length = clipped_end - clipped_start\n                            shard_offset = clipped_start - flat_shard_start\n                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                            new_param_packed_fragment = 
self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                            self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )\n                            if shard_id == self._group_rank:\n                                # copy model parameters into master buffer\n                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                #print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                                if not self._resume_from_checkpoint:\n                                    master_param_fragment.copy_(model_param_fragment)\n                                self._contrib_group_properties.append(group_props)\n                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy\n                                if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:\n                                    self._non_parallel_grads.append(opti_state_g_fragment)\n\n        p, m, v, g, p_copy = 
list(zip(*self._contrib_tensor_list))\n        self._contrib_tensor_list = [p, m, v, g, p_copy]\n\n        math_type = self._fp32_p.dtype\n        beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))\n        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')\n        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')\n        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')\n        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')\n        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')\n\n        p_in, p_out = zip(*self._packed_flat_to_model_params)\n        self._packed_flat_to_model_params = [p_in, p_out]\n\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for i in range(self._num_process_groups):\n                # gather global ranks of all members of the current process group\n                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]\n                for j in range(self._group_size):\n                    ar_idx = [j+k*self._group_size for k in range(self._num_groups)]\n                    ar_rank = [ranks[k] for k in ar_idx]\n                    #if self._global_rank in ar_rank:\n                    #    print(\"group for all reduce, ranks:\", ar_rank)\n                    for _ in range(self._num_ar_pg):\n                        grp = torch.distributed.new_group(ranks=ar_rank)\n                        if self._global_rank in ar_rank:\n                            self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            for ar_pg in self._ar_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n\n        self._rs_pg, rs_ranks = [],[]\n        for i in range(self._num_process_groups):\n            ranks = 
[i+k*self._num_process_groups for k in range(self._process_group_size)]\n            for j in range(self._num_groups):\n                rs_idx = [j*self._group_size+k for k in range(self._group_size)]\n                rs_rank = [ranks[k] for k in rs_idx]\n                # remember the rank lists; the all-gather group setup below indexes rs_ranks[j]\n                rs_ranks.append(rs_rank)\n                #if self._global_rank in rs_rank:\n                #    print(\"group for reduce scatter, ranks:\", rs_rank)\n                for _ in range(self._num_rs_pg):\n                    grp = torch.distributed.new_group(ranks=rs_rank)\n                    if self._global_rank in rs_rank:\n                        self._rs_pg.append(grp)\n                if self._compute_L2_grad_norm:\n                    l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)\n                    if self._global_rank in rs_rank:\n                        self._l2_grad_norm_pg = l2_grad_norm_pg\n                        torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        for rs_pg in self._rs_pg:\n            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n            for i in range(self._num_process_groups):\n                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]\n                for j in range(self._num_groups):\n                    ag_rank = rs_ranks[j]\n                    #if self._global_rank in ag_rank:\n                    #    print(\"group for all gather, ranks:\", ag_rank)\n                    for _ in range(self._num_ag_pg):\n                        grp = torch.distributed.new_group(ranks=ag_rank)\n                        if self._global_rank in ag_rank:\n                            self._ag_pg.append(grp)\n            self._ag_st = 
[torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            for ag_pg in self._ag_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None\n        self._completion_st = torch.cuda.Stream()\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n    def _init_everything(self):\n        if not self._init_done:\n            self._first_step_init()\n            self._init_done = True\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None]*self._num_chunks\n        for 
chunk_id in range(self._num_chunks):\n            glob_chunk_id = block_id * self._num_chunks + chunk_id\n            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n            rs_stream.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(rs_stream):\n                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                with torch.cuda.stream(ar_stream):\n                    works[chunk_id].wait()\n                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n        self._reductions_works[block_id] = works\n\n        # Optionally compute L2 grad norm\n        if self._compute_L2_grad_norm and block_id == 0:\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[block_id][chunk_id].wait()\n                # Since the packed format is contiguous after reductions, only one norm is needed\n                l2_grad_norm_sq = torch.empty([1], device='cuda')\n                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                # for model_parallel_rank=0, keep all gradients\n                # for the rest, subtract non_parallel gradients\n                if self._model_parallel and 
self._process_group_id: # non-zero model_parallel_rank\n                    non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')\n                    if len(self._non_parallel_grads): # non-parallel grads exist\n                        non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                                         self._overflow_buf,\n                                                                         [self._non_parallel_grads], False)[0]**2\n                    torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)\n                    l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq\n                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()\n\n    def __launch_step_kernel(self):\n        # If self._clip_grad_norm is False, we assume gradient clipping already\n        # happened outside the optimizer and self._global_scale has already\n        # been set to the combined scale, i.e. it is no longer the current loss\n        # scale used by the loss scaler.\n        # This covers model-parallel cases in which the global gradient norm must\n        # be obtained via an all-reduce outside the optimizer before clipping. 
\n        combined_scale = self._global_scale\n        if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        \n        self._step += 1\n        multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,\n                self._overflow_buf,\n                self._contrib_tensor_list, # p, m, v, g, p_copy\n                self._contrib_beta1,\n                self._contrib_beta2,\n                self._contrib_bias_correction,\n                self._contrib_epsilon,\n                self._contrib_weight_decay,\n                self._param_group['lr'],\n                combined_scale,\n                self._step,\n                self.eps_mode)\n\n    def _pipeline_step(self):\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel()\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, 
self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step:\n            if self._overlap_reductions:\n                flush_block = self._get_flush_block()\n                while flush_block:\n                    block_id = flush_block[0] // self._block_size\n                    self._pipeline_block_reductions(block_id)\n                    flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def has_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) method.\n        Clears the overflow flag.\n        \"\"\"\n        has_overflow = self._has_overflow\n        self._has_overflow = False\n        return has_overflow\n\n    @property\n    def peek_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) 
method.\n        Does not clear overflow flag.\n        \"\"\"\n        return self._has_overflow\n\n    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):\n        \"\"\"Strided check for overflow.\n        You can get status by calling has_overflow.\n        \"\"\"\n        if start >= 0 and start < end:\n            out_p = output_params[start:end]\n        else:\n            out_p = output_params\n        fused_adam_cuda.strided_check_finite(self._overflow_buf,\n                out_p,\n                stride,\n                1 if clear else 0)\n        self._has_overflow = False if self._overflow_buf.item() == 0 else True\n        return self._has_overflow\n\n    @property\n    def L2_grad_norm(self):\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n            return self._L2_grad_norm\n        else:\n            return None\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n        self._init_everything()\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        if self._compute_L2_grad_norm:\n            
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        self._pipeline_step()\n\n        with torch.cuda.stream(self._completion_st):\n            # Copy self._new_params to model params\n            multi_tensor_applier(\n                    fused_adam_cuda.maybe_cast_mt,\n                    self._overflow_buf,\n                    self._packed_flat_to_model_params)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        # save step, master weights and first/second moments\n        state_dict = {}\n        state_dict['step'] = self._step\n        state_dict['fp32_p'] = self._fp32_p\n        state_dict['fp32_m'] = self._fp32_m\n        state_dict['fp32_v'] = self._fp32_v\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If a DistributedFusedAdam instance was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``optimizer.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, 
D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # restore step, master weights and first/second moments\n        self._step = state_dict['step']\n        self._fp32_p = state_dict['fp32_p'].to(device=\"cuda\")\n        self._fp32_m = state_dict['fp32_m'].to(device=\"cuda\")\n        self._fp32_v = state_dict['fp32_v'].to(device=\"cuda\")\n        self._resume_from_checkpoint = True\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/distributed_fused_adam_v2.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass DistributedFusedAdamV2(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        num_prestats (integer, optional): number of fp64 stats that will be\n            reduced during first fp16 gradient reduction block. \n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,\n                 compute_L2_grad_norm=False, distributed_weight_update=0,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,\n                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,\n                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if use_mt:\n            raise RuntimeError('DistributedFusedAdam does not support use_mt.')\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdamV2, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n\n        assert (len(self.param_groups) == 1), \"More than one parameter group is not supported.\"\n\n        # Way to revert a step\n        # 3 -> undo kernel + double buffer (debug, print norm of difference)\n        # 2 -> double buffer fp32 parameters\n        # 1 -> undo kernel\n        self._revert_method = revert_method\n        if self._revert_method > 1:\n            print(\"revert_method -> double 
buffer fp32 parameters, will consume more memory\")\n\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._full_pipeline = full_pipeline\n        self._compute_L2_grad_norm = compute_L2_grad_norm\n        self._L2_grad_norm = None\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        p_offset = 0\n        p_i = 0\n        self._param_state = None\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            for p in group['params']:\n                torch.distributed.broadcast(p,0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                state = self.state[p]\n                if len(state) == 0:\n                    state['step'] = 0\n                if self._param_state is None:\n                    self._param_state = state\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n 
                   self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._flat_mt = flat_mt\n        self._grads = []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._shard_size = self._block_size // self._group_size\n        self._chunk_size = self._shard_size // self._num_chunks\n        print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            
self._low_param_i[block_id] = p_i\n        print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size\n        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        self._individual_flat_grads = []\n        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):\n            self._individual_flat_grads.append(self._flat_grads[grads_info[\"param_offset\"]:grads_info[\"param_offset\"]+grads_info[\"param_grads_size\"]].view_as(p))\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            def __chunkify(p):\n                # each shard splits into num_chunks chunks of chunk_size\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __blockify(self._flat_grads)\n            list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_chunks = 
[[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]\n            return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks\n        self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._chunk_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return 
list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        # current arrangement\n        # \n        # self._flat_grads\n        # self._flat_grads_blocks [x self._num_blocks, self._block_size]\n        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]\n        # self._flat_grads_shards [x self._group_size, self._shard_size]\n        #\n        # self._new_params\n        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]\n        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]\n        # self._new_params_mega_chunks [x self._num_chunks, self._shard_size]\n        #\n        # self._fp32_p\n        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]\n        # self._fp32_p_chunks [x self._num_chunks, self._shard_size]\n        # each chunk contains one shard\n        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g\n        #\n        # Usage:\n        # \n        # for chunk_id in range(self._num_chunks):\n        #   works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)\n        #\n        # ----------------------------------------------------------------------------------------\n        #\n        # new arrangement\n        #\n        # NB! 
New equations for self._shard_size and self._chunk_size\n        #\n        # self._flat_grads\n        # self._flat_grads_blocks [x self._num_blocks, self._block_size]\n        # self._flat_grads_shards [x self._group_size, self._shard_size]\n        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]\n        #\n        # self._new_params\n        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]\n        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]\n        # self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]\n        #\n        # self._fp32_p\n        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]\n        # self._fp32_p_chunks [x self._num_chunks, self._chunk_size]\n        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g\n        #\n        # Usage:\n        #\n        # work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)\n        # for chunk_id in range(self._num_chunks):\n        #   work.wait()\n        #   works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)\n        # or\n        # work.wait()\n        # works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)\n        #\n\n        # This paragraph does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size\n                flat_shard_end = flat_shard_start + self._shard_size\n                for p, grads_info in zip(self._model_params, self._grads_info):\n                    
flat_grad_start = grads_info[\"param_offset\"]\n                    flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                    clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)\n                    clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)\n                    if clipped_start < clipped_end:\n                        grad_offset = clipped_start - flat_grad_start\n                        grad_length = clipped_end - clipped_start\n                        shard_offset = clipped_start - flat_shard_start\n                        model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                        new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]\n                        self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )\n                        if shard_id == self._rank_in_group:\n                            # copy model parameters into master buffer\n                            master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]\n                            print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                            master_param_fragment.copy_(model_param_fragment)\n\n        p_in, p_out = zip(*self._packed_flat_to_model_params)\n        self._packed_flat_to_model_params = [p_in, p_out]\n\n        self._distributed_weight_update = distributed_weight_update # Is this still needed?\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for dev_i in 
range(self._group_size):\n                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]\n                for i in range(self._num_ar_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            for ar_pg in self._ar_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n        rs_ranks = []\n        for group_i in range(self._num_groups):\n            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])\n        self._rs_pg = []\n        for group_i in range(self._num_groups):\n            ranks = rs_ranks[group_i]\n            for i in range(self._num_rs_pg):\n                grp = torch.distributed.new_group(ranks=ranks)\n                if torch.distributed.get_rank() in ranks:\n                    self._rs_pg.append(grp)\n            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:\n                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        for rs_pg in self._rs_pg:\n            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n            for group_i in range(self._num_groups):\n                ranks = rs_ranks[group_i]\n                for i in range(self._num_ag_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        
self._ag_pg.append(grp)\n            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            for ag_pg in self._ag_pg:\n                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None\n        self._completion_st = torch.cuda.Stream()\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None]*self._num_chunks\n        rs_stream = self._rs_st[block_id%self._num_rs_pg]\n        
rs_stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(rs_stream):\n            rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)\n            for chunk_id in range(self._num_chunks):\n                works[chunk_id] = rs_work\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                with torch.cuda.stream(ar_stream):\n                    rs_work.wait()\n                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n        self._reductions_works[block_id] = works\n\n        # Optionally compute L2 grad norm\n        if self._compute_L2_grad_norm and block_id == 0:\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                # Wait on every block; use a separate loop variable to avoid shadowing block_id\n                for norm_block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[norm_block_id][chunk_id].wait()\n                # Since the packed format is contiguous after reductions, only one norm is needed\n                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()\n\n    def __launch_step_kernel(self, p, p_copy, m, v, g):\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] 
/ (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.reversible_adam(\n                p, p_copy, m, v, g,\n                self._param_group['lr'],\n                beta1,\n                beta2,\n                self._param_group['eps'],\n                combined_scale,\n                self._param_state['step']+1,\n                self.eps_mode,\n                bias_correction,\n                self._param_group['weight_decay'])\n\n    def _pipeline_block_step(self, block_id):\n        # Call step kernel once per block\n        ag_stream = self._ag_st[block_id%self._num_ag_pg]\n        with torch.cuda.stream(ag_stream):\n            for chunk_id in range(self._num_chunks):\n                self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel(\n                self._fp32_p_blocks[block_id],\n                self._fp16_p_blocks[block_id],\n                self._fp32_m_blocks[block_id],\n                self._fp32_v_blocks[block_id],\n                self._fp16_g_blocks[block_id])\n        # Call all-gather once per step.\n        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end\n        if block_id == 0:\n            for other_ag_stream in self._ag_st:\n                self._completion_st.wait_stream(other_ag_stream)\n            with torch.cuda.stream(self._completion_st):\n                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _pipeline_step(self):\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in 
range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            self.__launch_step_kernel(\n                self._fp32_p,\n                self._fp16_p,\n                self._fp32_m,\n                self._fp32_v,\n                self._fp16_g)\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step:\n            if self._overlap_reductions:\n                flush_block = self._get_flush_block()\n                while flush_block:\n                    block_id = flush_block[0] // self._block_size\n                    self._pipeline_block_reductions(block_id)\n                    if self._full_pipeline:\n                        self._pipeline_block_step(block_id)\n                    flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def has_overflow(self):\n        \"\"\"Check if overflows were detected by any call to 
step(...) method.\n        Clears the overflow flag.\n        \"\"\"\n        has_overflow = self._has_overflow\n        self._has_overflow = False\n        return has_overflow\n\n    @property\n    def peek_overflow(self):\n        \"\"\"Check if overflows were detected by any call to step(...) method.\n        Does not clear overflow flag.\n        \"\"\"\n        return self._has_overflow\n\n    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):\n        \"\"\"Strided check for overflow.\n        You can get status by calling has_overflow.\n        \"\"\"\n        if start >= 0 and start < end:\n            out_p = output_params[start:end]\n        else:\n            out_p = output_params\n        fused_adam_cuda.strided_check_finite(self._overflow_buf,\n                out_p,\n                stride,\n                1 if clear else 0)\n        self._has_overflow = False if self._overflow_buf.item() == 0 else True\n        return self._has_overflow\n\n    @property\n    def L2_grad_norm(self):\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n            return self._L2_grad_norm\n        else:\n            return None\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not 
self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        if self._compute_L2_grad_norm:\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def revert_step(self):\n        \"\"\"Revert effect of previously calling partial_step.\n        \"\"\"\n        # Call undo kernel once per step\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.maybe_adam_undo(\n                    torch.empty([0]),\n                    self._fp32_p,\n                    self._fp32_m,\n                    self._fp32_v,\n                    self._fp16_g,\n                    self._param_group['lr'],\n                    beta1,\n                    beta2,\n                    self._param_group['eps'],\n                    combined_scale,\n                    self._param_state['step']+1,\n                    self.eps_mode,\n                    bias_correction,\n                    self._param_group['weight_decay'])\n\n    def step(self, closure=None, skip_overflow_check=False):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if self._last_step or not self._overlap_reductions or not self._full_pipeline:\n            self._pipeline_step()\n\n        with torch.cuda.stream(self._completion_st):\n            # Check for 
overflow\n            # Store state for loss scaler calculation\n            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)\n            if has_overflow:\n                self.revert_step()\n            else:\n                # Copy self._new_params to model params\n                for p in self._model_params: self.state[p]['step'] += 1\n                multi_tensor_applier(\n                        fused_adam_cuda.maybe_cast_mt,\n                        self._overflow_buf,\n                        self._packed_flat_to_model_params)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/distributed_fused_adam_v3.py",
    "content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass DistributedFusedAdamV3(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n        overlap_reductions(boolean, optional): whether to overlap reductions\n            with bprop (default: True)\n        num_prestats (integer, optional): number of fp64 stats that will be\n            reduced during first fp16 gradient reduction block. \n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,\n                 compute_L2_grad_norm=False, distributed_weight_update=0,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,\n                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,\n                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,\n                 do_not_flatten_model=False):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if use_mt:\n            raise RuntimeError('DistributedFusedAdam does not support use_mt.')\n        if amsgrad:\n            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')\n\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(DistributedFusedAdamV3, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n\n        assert (len(self.param_groups) == 1), \"More than one parameter group is not supported.\"\n\n        # Way to revert a step\n        # 3 -> undo kernel + double buffer (debug, print norm of difference)\n        # 2 -> double buffer fp32 parameters\n        # 1 -> undo kernel\n        self._revert_method = revert_method\n        if self._revert_method > 1:\n            print(\"revert_method -> double buffer fp32 parameters, will consume 
more memory\")\n\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._predivide = predivide\n        self._e5m2_allgather = e5m2_allgather\n        self._do_not_flatten_model = do_not_flatten_model\n        self._full_pipeline = full_pipeline\n        self._L2_grad_norm = None\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        p_offset = 0\n        p_i = 0\n        self._param_state = None\n        self._model_params = []\n        self._grads_info = []\n        self._grad_accs = []\n        for group in self.param_groups:\n            self._param_group = group\n            prev = None\n            for p in group['params']:\n                torch.distributed.broadcast(p,0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                state = self.state[p]\n                if len(state) == 0:\n                    state['step'] = 0\n                if self._param_state is None:\n                    self._param_state = state\n                p_grads_size = p.numel()\n                def wrapper(param, param_i, param_grads_size, param_offset):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                self._grads_info.append({\"param_grads_size\":p_grads_size, 
\"param_offset\":p_offset})\n                wrapper(p, p_i, p_grads_size, p_offset)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._grads_info)\n        self._flat_mt = flat_mt\n        self._grads = []\n        self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._shard_size = self._total_param_size // self._group_size\n        print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        print(self._low_param_i)\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._flat_params = torch.zeros_like(self._flat_grads)\n\n        def _flat_split(flat):\n            def 
__flat_blockify(flat):\n                return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __flat_shardify(flat):\n                return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            return __flat_blockify(flat), __flat_shardify(flat)\n        self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)\n        self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)\n\n        # master params\n        self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n        self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')\n\n        # copy model params to flat_params and set_ model params to flat_params.\n        self._individual_flat_grads = []\n        with torch.no_grad():\n            for p, grads_info in zip(self._model_params, self._grads_info):\n                start = grads_info[\"param_offset\"]\n                end = start + grads_info[\"param_grads_size\"]\n                flat_p = self._flat_params[start:end].view_as(p)\n                flat_p.copy_(p)\n                p.set_(flat_p)\n                flat_grad = self._flat_grads[start:end]\n                self._individual_flat_grads.append(flat_grad)\n        self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())\n\n        self._dwu_st = torch.cuda.Stream()\n        self._l2_grad_norm_st = torch.cuda.Stream()\n        for group_i in range(self._num_groups):\n            ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]\n            pg = torch.distributed.new_group(ranks=ranks)\n            if torch.distributed.get_rank() in ranks:\n                self._ag_pg = pg\n                
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n    @property\n    def has_overflow(self):\n        return self.L2_grad_norm is not None and not math.isfinite(self.L2_grad_norm)\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def __launch_step_kernel(self, p, p_copy, m, v, g):\n        combined_scale = self._global_scale\n        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):\n            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)\n            combined_scale = self._global_scale / min(1, combined_scale)\n        bias_correction = 1 if self._param_group['bias_correction'] else 0\n        beta1, beta2 = self._param_group['betas']\n        fused_adam_cuda.reversible_adam(\n                p, p_copy, m, v, g,\n                self._param_group['lr'],\n                beta1,\n                beta2,\n                self._param_group['eps'],\n                
combined_scale,\n                self._param_state['step']+1,\n                self.eps_mode,\n                bias_correction,\n                self._param_group['weight_decay'])\n\n    def _flatten_grad_mt(self, scale):\n        if self._flat_mt and len(self._grads) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads)),\n                    scale)\n            self._grads = []\n\n    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):\n        # handle overlapped reductions\n        if self._flat_mt:\n            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )\n        else:\n            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])\n        self._grads_generated[param_i]=True\n        if not self._last_step and self._overlap_reductions:\n            flush_block = self._get_flush_block()\n            while flush_block:\n                block_id = flush_block[0] // self._block_size\n                self._dwu_st.wait_stream(torch.cuda.current_stream())\n                with torch.cuda.stream(self._dwu_st):\n                    self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n                    torch.distributed.all_reduce(self._flat_grads_blocks[block_id])\n                if block_id == 0:\n                    self._l2_grad_norm_st.wait_stream(self._dwu_st)\n                    with torch.cuda.stream(self._l2_grad_norm_st):\n                        self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()\n                flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n    
    return self._global_scale\n\n    @property\n    def L2_grad_norm(self):\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n        return self._L2_grad_norm\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, flat_grad in enumerate(self._individual_flat_grads):\n                if not self._grads_generated[param_i]:\n                    flat_grad.zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            self._dwu_st.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(self._dwu_st):\n                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)\n                torch.distributed.all_reduce(self._flat_grads)\n            self._l2_grad_norm_st.wait_stream(self._dwu_st)\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None, skip_overflow_check=False):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        with torch.cuda.stream(self._dwu_st):\n            self.__launch_step_kernel(\n                self._fp32_p,\n                self._flat_params_shards[self._rank_in_group],\n                self._fp32_m,\n                self._fp32_v,\n                self._flat_grads_shards[self._rank_in_group])\n            torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, 
no_copy=True)\n            for p in self._model_params: self.state[p]['step'] += 1\n\n        torch.cuda.current_stream().wait_stream(self._dwu_st)\n\n        return loss\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/distributed_fused_lamb.py",
"content": "import math\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nimport torch.distributed.distributed_c10d as c10d\n\nclass DistributedFusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n    \n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n    \n    This version of fused LAMB implements 2 fusions.\n      \n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n    \n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary PyTorch optimizer::\n        \n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n    \n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n        \n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1\" or \"O2\")\n        ...\n        opt.step()\n    \n    In general, ``opt_level=\"O1\"`` is recommended.\n    \n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n    \n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): whether to apply decoupled weight decay\n            (also known as AdamW) instead of L2 regularization (default: True)\n        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when the\n            zero_grad() method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n        use_nvlamb (boolean, optional): whether to apply the adaptive learning\n            rate to parameters with 0.0 weight decay (default: False)\n        step_supports_amp_scaling (boolean, optional): whether to use customized\n            gradient unscaling logic (default: True)\n    \n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    class AtomicCounter(object):\n        def __init__(self):\n            self.value = 0\n            self.order = []\n            import threading\n            self._lock = threading.Lock()\n\n        def add(self, idx):\n            with self._lock:\n                self.value += 1\n                self.order.append(idx)\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True, grad_averaging=True,\n                 betas=(0.9, 0.999), eps=1e-8, \n                 weight_decay=0., max_grad_norm=0., \n                 adam_w_mode=True, use_nvlamb=False,\n                 step_supports_amp_scaling=True, overlap_reductions=True,\n                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,\n                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, \n                 e5m2_allgather=False, verbose=False, clip_after_ar=True):\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n\n        super(DistributedFusedLAMB, self).__init__(params, defaults)\n\n        global fused_adam_cuda, distributed_lamb_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n        distributed_lamb_cuda = importlib.import_module(\"distributed_lamb_cuda\")\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term\n        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights\n        import amp_C\n        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n\n        self._grad_averaging = grad_averaging\n        self._adam_w_mode = 1 
if adam_w_mode else 0\n        self._use_nvlamb = use_nvlamb\n        self._step_supports_amp_scaling = step_supports_amp_scaling\n        self._is_accumulation_step = False\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._e5m2_allgather = e5m2_allgather\n        self._verbose = verbose\n        self._clip_after_ar = clip_after_ar\n        self._L2_grad_norm = None\n        \n        self._current_process_group = c10d._get_default_group()\n        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')\n\n        self._resume_from_checkpoint = False\n        self._step = torch.cuda.IntTensor([0])\n\n        # Master weight, moment, gradient buffers\n        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None\n\n        import inspect\n        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \"This version of c10d does not support no_copy option\"\n\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n        if self._num_groups > 1:\n            self._ar_pg = []\n            for dev_i in range(self._group_size):\n                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]\n                for i in range(self._num_ar_pg):\n                    if self._verbose:\n                        print(f\"creating new group 
{i}: {ranks}\")\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:\n                        if self._verbose:\n                            print(f\"group {i}: init barrier (device: {torch.cuda.current_device()})\")\n                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])\n                    if self._verbose:\n                        print(f\"created new group {i}\")\n\n                    if torch.distributed.get_rank() in ranks:\n                        self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            #for ar_pg in self._ar_pg:\n            #    torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)\n        rs_ranks = []\n        for group_i in range(self._num_groups):\n            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])\n        self._rs_pg = []\n        for group_i in range(self._num_groups):\n            ranks = rs_ranks[group_i]\n            for i in range(self._num_rs_pg):\n                grp = torch.distributed.new_group(ranks=ranks)\n                if torch.distributed.get_rank() in ranks:\n                    self._rs_pg.append(grp)\n            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n            if torch.distributed.get_rank() in ranks:\n                self._l2_grad_norm_pg = l2_grad_norm_pg\n                #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)\n        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n        #for rs_pg in self._rs_pg:\n        #    torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)\n        if self._num_ag_pg == 0:\n            self._ag_pg = self._rs_pg\n            self._ag_st = self._rs_st\n            self._num_ag_pg = self._num_rs_pg\n        else:\n            self._ag_pg = []\n         
   for group_i in range(self._num_groups):\n                ranks = rs_ranks[group_i]\n                for i in range(self._num_ag_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        self._ag_pg.append(grp)\n            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n            #for ag_pg in self._ag_pg:\n            #    torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)\n        self._l2_grad_norm_st = torch.cuda.Stream()\n        self._completion_st = torch.cuda.Stream()\n        self._step.record_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        self._one = torch.cuda.IntTensor([1])\n\n        self._first_step = True\n        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False\n        self._param_order = self.AtomicCounter()\n\n    def _lazy_init_stage1(self):\n        if self._lazy_init_stage1_done: return\n\n        p_offset = 0\n        p_i = 0\n        self._model_params = []\n        self._grad_accs = []\n        self._group_properties = []\n        for group in self.param_groups:\n            prev = None\n            beta1, beta2 = group['betas']\n            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0\n            bias_correction = 1 if group['bias_correction'] else 0\n            eps = group['eps']\n            weight_decay = group['weight_decay']\n            for p in group['params']:\n                torch.distributed.broadcast(p, 0)\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                self._group_properties.append((\n                    weight_decay,\n                    bias_correction,\n                    beta1,\n                    beta2,\n                    beta3,\n                    eps\n    
                ))\n                p_grads_size = p.numel()\n                def wrapper(param, param_i):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    def allreduce_hook(*unused):\n                        if self._first_step:\n                            # first time\n                            self._param_order.add(param_i)\n                        else:\n                            idx = self._param_order.order.index(param_i)\n                            self._do_overlapped_reduction(idx, param)\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n                wrapper(p, p_i)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        self._grads_generated = [False]*len(self._model_params)\n        self._grads_fp16, self._grads_fp32 = [], []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size\n        self._block_size = self._total_param_size // self._num_blocks\n        self._chunk_size = self._block_size // self._num_chunks\n        self._shard_size = self._chunk_size // self._group_size\n        
#print(\"self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d\" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')\n        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size\n        # initialize master weights, moments buffers if not loaded from checkpoint\n        if self._fp32_p is None:\n            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n            def __shardify(p):\n                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]\n            list_of_blocks = 
__blockify(self._flat_grads)\n            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]\n            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards\n        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]\n            def __blockify(p):\n                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]\n            def __chunkify(p):\n                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]\n            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]\n            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks\n        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks*self._shard_size\n                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]\n            def __packed_chunkify(p):\n                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size\n                return 
[p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return list_of_blocks, list_of_list_of_chunks\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        self._lazy_init_stage1_done = True\n\n    def _lazy_init_stage2(self):\n        if self._lazy_init_stage2_done: return\n\n        self._param_order.order.reverse()\n\n        # re-order model_params, grad_accs, group_properties lists\n        self._model_params = [self._model_params[i] for i in self._param_order.order]\n        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]\n        self._group_properties = [self._group_properties[i] for i in self._param_order.order]\n\n        # re-collect grads info (size, offset) after ordering\n        prev = None\n        p_offset = 0\n        self._grads_info = []\n        self._individual_flat_grads = []\n        for i, p in enumerate(self._model_params):\n            p_grads_size = p.numel()\n            self._grads_info.append({\"param_grads_size\":p_grads_size, \"param_offset\":p_offset})\n            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))\n            # for the first iteration\n            self._do_overlapped_reduction(i, p)\n            p_offset += p_grads_size\n            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n            # 
RNN is one example of consecutive parameters:\n            # (weight_ih, weight_hh, bias_ih, bias_hh)\n            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):\n                p_offset = ((p_offset + 63) // 64) * 64\n            prev = p\n\n        self._low_param_i = [0]*self._num_blocks\n        for block_id in range(self._num_blocks-1,-1,-1):\n            p_i = len(self._grads_info)-1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id*self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        #print(\"self._low_param_i\", self._low_param_i)\n\n        # This paragraph does two things:\n        # 1) Copy model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params_fp16 = []\n        self._packed_flat_to_model_params_fp32 = []\n        self._model_params_num = len(self._model_params)\n        self._contrib_tensor_list = []\n        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1\n        self._contrib_update_frag_for_norm = []\n        self._contrib_model_param_for_norm_fp16 = []\n        self._contrib_model_param_for_norm_fp32 = []\n        self._contrib_model_param_for_norm_is_fp16 = []\n        self._model_param_is_contrib = []\n        self._contrib_group_properties = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size\n                    flat_shard_end = flat_shard_start + self._shard_size\n                    for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):\n                        flat_grad_start = 
grads_info[\"param_offset\"]\n                        flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                        clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)\n                        clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)\n                        if clipped_start < clipped_end:\n                            grad_offset = clipped_start - flat_grad_start\n                            grad_length = clipped_end - clipped_start\n                            shard_offset = clipped_start - flat_shard_start\n                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]\n                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                            if model_param_fragment.dtype == torch.float16:\n                                self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )\n                            else:\n                                self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )\n                            if shard_id == self._rank_in_group:\n                                self._model_param_is_contrib.append(param_i)\n                                # copy model parameters into master buffer\n                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_u_fragment = 
self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]\n                                #print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                                if not self._resume_from_checkpoint:\n                                    master_param_fragment.copy_(model_param_fragment)\n                                self._contrib_group_properties.append(group_props)\n                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy\n                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)\n                                if p.dtype == torch.float16:\n                                    self._contrib_model_param_for_norm_fp16.append(p)\n                                else:\n                                    self._contrib_model_param_for_norm_fp32.append(p)\n                                self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)\n                                if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i\n                                self._contrib_max_param_i = param_i\n        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)\n        if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None\n        if 
len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None\n        self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')\n        self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')\n        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')\n\n        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))\n        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]\n        self._contrib_update_weights_tensor_list = [u, p, p_copy]\n\n        math_type = self._fp32_u.dtype\n        decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))\n        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')\n        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')\n        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')\n        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')\n        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')\n        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')\n\n        self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None\n        self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None\n\n        self._lazy_init_stage2_done = True\n\n        self.complete_reductions()\n        self._first_step = False\n\n    def set_is_accumulation_step(self, is_accumulation_step):\n        self._is_accumulation_step = 
is_accumulation_step\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n        \n    def _get_flush_block(self):\n        flush_block = []\n        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:\n                contiguous_idx -= 1\n\n            if contiguous_idx < num_grads and self._grads_info[contiguous_idx][\"param_offset\"] <= (self._current_block-1)*self._block_size:\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block+1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _pipeline_block_reductions(self, block_id):\n        if self._clip_after_ar:\n            self._flatten_grad_mt(1.0/self._world_size)\n\n            # Reduction within each node\n            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n            # The output format is the same as the fp32 master parameters\n            works = [None]*self._num_chunks\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n                rs_stream.wait_stream(torch.cuda.current_stream())\n                with torch.cuda.stream(rs_stream):\n                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n            # Reduction across nodes for each rank\n            if self._num_groups > 1:\n                for chunk_id in range(self._num_chunks):\n                
    glob_chunk_id = block_id * self._num_chunks + chunk_id\n                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                    with torch.cuda.stream(ar_stream):\n                        works[chunk_id].wait()\n                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n            self._reductions_works[block_id] = works\n\n            # Compute L2 grad norm\n            if block_id == 0:\n                with torch.cuda.stream(self._l2_grad_norm_st):\n                    for block_id in range(self._num_blocks):\n                        for chunk_id in range(self._num_chunks):\n                            self._reductions_works[block_id][chunk_id].wait()\n                    # Since the packed format is contiguous after reductions, only one norm is needed\n                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2\n                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()\n        else:\n            # Copy model grads to flat grads buffer\n            self._flatten_grad_mt(1.0)\n\n            # Compute L2 grad norm\n            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n            # Apply clipping & pre-reduction scaling on grads\n            loss_scale = self.global_scale\n            max_grad_norm = loss_scale * self.defaults['max_grad_norm']\n            coeff = max_grad_norm / (1e-6 + self.L2_grad_norm)\n            coeff = (coeff>1) * self._one + (coeff<=1) * coeff\n            
tmp = torch.cat(((self._one), (coeff)))\n            index = (coeff+1>coeff).int()\n            scale = tmp.index_select(0, index).half()/self._world_size\n            self._flat_grads.mul_(scale)\n\n            # Reduction within each node\n            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n            # The output format is the same as the fp32 master parameters\n            works = [None]*self._num_chunks\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]\n                rs_stream.wait_stream(torch.cuda.current_stream())\n                rs_stream.wait_stream(self._l2_grad_norm_st)\n                with torch.cuda.stream(rs_stream):\n                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)\n\n            # Reduction across nodes for each rank\n            if self._num_groups > 1:\n                for chunk_id in range(self._num_chunks):\n                    glob_chunk_id = block_id * self._num_chunks + chunk_id\n                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]\n                    with torch.cuda.stream(ar_stream):\n                        works[chunk_id].wait()\n                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)\n            self._reductions_works[block_id] = works\n\n            if block_id == 0:\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n                        self._reductions_works[block_id][chunk_id].wait()\n\n    def __compute_contrib_param_norm(self):\n        if 
self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:\n            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]\n            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]\n            # per-parameter norms are floating point; the destination buffer must match\n            # their dtype for masked_scatter_ to preserve the values\n            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda')\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)\n        elif self._contrib_model_param_for_norm_fp16 is not None:\n            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]\n        elif self._contrib_model_param_for_norm_fp32 is not None:\n            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]\n        return gnorm\n\n    def __compute_contrib_update_norm(self):\n        l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')\n        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2\n        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)\n        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])\n        l2_norm = torch.sqrt(l2_norm)\n        return l2_norm\n\n    def _pipeline_step(self):\n        global_scale = self.global_scale\n        # if clip before ar, set max_grad_norm to 0\n        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n        global_grad_norm = self.L2_grad_norm\n\n        # check global_grad_norm and 
fill overflow_buf\n        is_finite = (global_grad_norm + 1 > global_grad_norm).int()\n        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1\n        torch.distributed.all_reduce(is_finite,\n                                     op=torch.distributed.ReduceOp.MIN,\n                                     group=self._current_process_group)\n        torch.distributed.all_reduce(self._overflow_buf,\n                                     op=torch.distributed.ReduceOp.MAX,\n                                     group=self._current_process_group)\n\n        # increment step counter if no overflow\n        self._step += is_finite\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            param_norm = self.__compute_contrib_param_norm()\n            multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,\n                    self._overflow_buf,\n                    self._contrib_compute_update_term_tensor_list, # g, p, m, v, u\n                    self._contrib_beta1,\n                    self._contrib_beta2,\n                    self._contrib_beta3,\n                    self._contrib_bias_correction,\n                    self._step,\n                    self._contrib_epsilon,\n                    self._adam_w_mode,\n                    self._contrib_weight_decay,\n                    global_scale,\n                    global_grad_norm,\n                    max_grad_norm)\n            upd_norm = self.__compute_contrib_update_norm()\n            multi_tensor_applier(self.multi_tensor_lamb_update_weights,\n                    
self._overflow_buf,\n                    self._contrib_update_weights_tensor_list, # u, p, p_copy\n                    param_norm,\n                    upd_norm,\n                    self._offsets,\n                    self._lr,\n                    self._contrib_weight_decay,\n                    global_grad_norm,\n                    self._use_nvlamb)\n            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)\n\n    def _flatten_grad_mt(self, scale):\n        if len(self._grads_fp16) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp16)),\n                    scale)\n            self._grads_fp16 = []\n        if len(self._grads_fp32) > 0:\n            self._overflow_buf.zero_()\n            multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp32)),\n                    scale)\n            self._grads_fp32 = []\n\n    def _do_overlapped_reduction(self, param_i, param):\n        if not self._is_accumulation_step:\n            # handle overlapped reductions\n            if param.dtype == torch.float16:\n                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )\n            else:\n                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )\n            self._grads_generated[param_i]=True\n            if not self._first_step and not self._last_step:\n                if self._overlap_reductions:\n                    flush_block = self._get_flush_block()\n                    while flush_block:\n                        block_id = flush_block[0] // self._block_size\n                        self._pipeline_block_reductions(block_id)\n                        flush_block = 
self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\n        \"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def L2_grad_norm(self):\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n        return self._L2_grad_norm\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\n        \"\"\"\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset:param_offset+param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._first_step or self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks-1,-1,-1):\n                self._pipeline_block_reductions(block_id)\n\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False]*len(self._grads_info)\n\n    def step(self, closure=None, grad_scaler=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        self._pipeline_step()\n\n        if grad_scaler is not None:\n            found_inf = self._overflow_buf.float()\n            optimizer_state = grad_scaler._per_optimizer_states[id(self)]\n            current_device = torch.device('cuda', torch.cuda.current_device())\n            
optimizer_state[\"found_inf_per_device\"][current_device] = found_inf\n\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n\n        with torch.cuda.stream(self._completion_st):\n            # Copy self._new_params to model params\n            with torch.no_grad():\n                if self._packed_flat_to_model_params_fp16 is not None:\n                    multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n                            self._overflow_buf,\n                            self._packed_flat_to_model_params_fp16)\n                if self._packed_flat_to_model_params_fp32 is not None:\n                    multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n                            self._overflow_buf,\n                            self._packed_flat_to_model_params_fp32)\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None]*self._num_blocks\n        self._allgather_works = [None]*self._num_blocks\n\n        return loss\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        # save step, master weights and first/second moments\n        state_dict = {}\n        state_dict['step'] = self._step\n        state_dict['fp32_p'] = self._fp32_p\n        state_dict['fp32_m'] = self._fp32_m\n        state_dict['fp32_v'] = self._fp32_v\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,\n        whose 
parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``optimizer.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # restore step, master weights and first/second moments\n        self._step = state_dict['step']\n        self._fp32_p = state_dict['fp32_p'].to(device=\"cuda\")\n        self._fp32_m = state_dict['fp32_m'].to(device=\"cuda\")\n        self._fp32_v = state_dict['fp32_v'].to(device=\"cuda\")\n        self._resume_from_checkpoint = True\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/fp16_optimizer.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FP16_Optimizer(object):\n    \"\"\"\n    :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.\n    Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.\n    Refer to apex.fp16_utils documents for more information.\n    Example::\n        model = torch.nn.Linear(D_in, D_out).cuda().half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())\n        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n        ...\n        # loss.backward() becomes:\n        optimizer.backward(loss)\n        ...\n    Example with dynamic loss scaling::\n        ...\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n                                   # optional arg to control dynamic loss scaling behavior\n                                   # dynamic_loss_args={'scale_window' : 500})\n                                   # Usually, dynamic_loss_args is not necessary.\n    \"\"\"\n\n    def __init__(self,\n                 init_optimizer,\n                 static_loss_scale=1.0,\n                 dynamic_loss_scale=False,\n                 dynamic_loss_args=None,\n                 verbose=True):\n\n        print(\"\\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*\")\n        print(\"To update, use updated optimizers with AMP.\")\n        # The fused optimizer does all the work. We need this layer for two reason:\n        # 1. maintain same user API from apex.fp16_utils\n        # 2. 
keep common stuff here in case we need to add new fused optimizer later\n\n        if not torch.cuda.is_available():\n            raise SystemError(\"Cannot use fp16 without CUDA.\")\n        self.optimizer = init_optimizer\n\n        self.fp16_groups = [] # model params\n        self.fp32_groups = [] # master weights\n\n        # iterate over param_groups\n        for param_group in self.optimizer.param_groups:\n            fp16_group = []\n            fp32_group = []\n            for p in param_group['params']:\n                fp16_group.append(p)\n                fp32_group.append(p.clone().float().detach())\n            self.fp16_groups.append(fp16_group)\n            self.fp32_groups.append(fp32_group)\n            param_group['params'] = fp32_group\n\n        if multi_tensor_applier.available:\n            import amp_C\n            self.overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n        else:\n            raise RuntimeError('FP16_Optimizer requires cuda extensions')\n\n        # we may have a way of fusing dynamic scale. 
Do not support for now\n        if dynamic_loss_scale:\n            if dynamic_loss_args is not None:\n                raise SystemError(\"Do not support dynamic loss scale args for now.\")\n            self.dynamic_loss_scale = True\n            self.cur_scale = 2**16\n            self.cur_iter = 0\n            self.last_overflow_iter = -1\n            self.scale_factor = 2\n            self.scale_window = 1000\n        else:\n            self.dynamic_loss_scale = False\n            self.cur_iter = 0\n            self.cur_scale = static_loss_scale\n        self.verbose = verbose\n\n    def zero_grad(self, set_grads_to_None=True):\n        \"\"\"\n        Zero FP16 parameter grads.\n        \"\"\"\n        # FP32 grad should never exist.\n        # For speed, set model fp16 grad to None by default\n        for group in self.fp16_groups:\n            for p in group:\n                if set_grads_to_None:\n                    p.grad = None\n                else:\n                    if p.grad is not None:\n                        p.grad.detach_()\n                        p.grad.zero_()\n\n    def step(self, closure=None):\n        \"\"\"\n        Not supporting closure.\n        \"\"\"\n        fp16_grads = []\n        norm_groups = []\n        skip = False\n\n        for group in self.fp16_groups:\n            fp16_grad = []\n            for i, p in enumerate(group):\n                fp16_grad.append(p.grad)\n            fp16_grads.append(fp16_grad)\n        \n        # nan check\n        self.overflow_buf.zero_()\n        for fp16_grad in fp16_grads:\n            if len(fp16_grad) > 0:\n                norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                             self.overflow_buf,\n                                                             [fp16_grad], True)\n                norm_groups.append(norm)\n                if self.overflow_buf.item() != 0:\n                    skip = True\n\n        
if skip:\n            self._update_scale(skip)\n            return\n\n        # norm is in fact norm*cur_scale\n        self.optimizer.step(grads=fp16_grads,\n                            output_params=self.fp16_groups,\n                            scale=self.cur_scale,\n                            grad_norms=norm_groups)\n\n        self._update_scale(False)\n        return\n\n    def backward(self, loss):\n        \"\"\"\n        :attr:`backward` performs the following steps:\n        1. fp32_loss = loss.float()\n        2. scaled_loss = fp32_loss*loss_scale\n        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves\n        \"\"\"\n        scaled_loss = (loss.float()) * self.cur_scale\n        scaled_loss.backward()\n\n    def _update_scale(self, skip):\n        if self.dynamic_loss_scale:\n            if skip:\n                if self.verbose:\n                    print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                    print(\"Using dynamic loss scale of\", self.cur_scale)\n                self.cur_scale = max(self.cur_scale/self.scale_factor, 1)\n                self.last_overflow_iter = self.cur_iter\n            else:\n                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:\n                    self.cur_scale *= self.scale_factor\n        else:\n            if skip:\n                print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                print(\"Using static loss scale of\", self.cur_scale)\n        self.cur_iter +=1\n        return\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for 
example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return self.optimizer.param_groups\n\n    def _set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale\n        state_dict['cur_scale'] = self.cur_scale\n        state_dict['cur_iter'] = self.cur_iter\n        if state_dict['dynamic_loss_scale']:\n            state_dict['last_overflow_iter'] = self.last_overflow_iter\n            state_dict['scale_factor'] = self.scale_factor\n            state_dict['scale_window'] = self.scale_window\n        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()\n        state_dict['fp32_groups'] = self.fp32_groups\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, 
static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # I think it should actually be ok to reload the optimizer before the model.\n        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']\n        self.cur_scale = state_dict['cur_scale']\n        self.cur_iter = state_dict['cur_iter']\n        if state_dict['dynamic_loss_scale']:\n            self.last_overflow_iter = state_dict['last_overflow_iter']\n            self.scale_factor = state_dict['scale_factor']\n            self.scale_window = state_dict['scale_window']\n        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.\n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  There are two options.\n        # 1:  Refresh the master params from the model's fp16 params.\n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 2.\n        #\n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device\n        # of their associated parameters, because it's possible those buffers might not exist yet in\n        # the current optimizer instance.  
In our case, as long as the current FP16_Optimizer has been\n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n        for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):\n            for _current, _saved in zip(current, saved):\n                _current.data.copy_(_saved.data)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/fused_adam.py",
    "content": "import types\nimport torch\nimport importlib\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n\n    .. _Adam - A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params,\n                 lr=1e-3, bias_correction = True,\n                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,\n                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,\n                 amp_scale_adjustment=1.0):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._use_multi_tensor = False\n        if use_mt:\n            if not multi_tensor_applier.available:\n                print(\"Warning:  multi_tensor_applier is unavailable\")\n            else:\n                self._use_multi_tensor = True\n                self._overflow_buf = torch.cuda.IntTensor([0])\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if amsgrad:\n            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm)\n        super(FusedAdam, self).__init__(params, defaults)\n        self.eps_mode = 0 if  eps_inside_sqrt else 1\n\n    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. (default: None)\n            output params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. 
Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. (default: 1)\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if hasattr(self, \"_amp_stash\"):\n            grads = self._amp_stash.grads\n            output_params = self._amp_stash.output_params\n            scale = self._amp_stash.scale*self._amp_scale_adjustment\n            grad_norms = self._amp_stash.grad_norms\n\n        if grads is None:\n            grads_group = [None]*len(self.param_groups)\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0])!=list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            output_params_group = [None]*len(self.param_groups)\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0])!=list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        if grad_norms is None:\n            grad_norms = [None]*len(self.param_groups)\n\n        for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):\n            if grads_this_group is None:\n               grads_this_group = [None]*len(group['params'])\n            if output_params_this_group is None:\n               output_params_this_group = [None]*len(group['params'])\n\n            # compute combined scale factor for this group\n            combined_scale = scale\n            if group['max_grad_norm'] > 0:\n                # norm is in fact norm*scale\n            
    clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']\n                if clip > 1:\n                    combined_scale = clip * scale\n\n            bias_correction = 1 if group['bias_correction'] else 0\n\n            if self._use_multi_tensor:\n                if output_params:\n                    tensorlists = [[],[],[],[],[]]\n                else:\n                    tensorlists = [[],[],[],[]]\n                tensordevice = None\n\n            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):\n                #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients\n                if p.grad is None and grad is None:\n                    continue\n                if grad is None:\n                    grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param\n                if self._use_multi_tensor:\n                    pl = [p.data, exp_avg, exp_avg_sq, grad]\n                    if output_param is not None:\n                        pl.append(out_p)\n\n                    for tl, t in zip(tensorlists, pl):\n             
           tl.append(t)\n\n                    if tensordevice is None:\n                        tensordevice = p.device\n                    elif tensordevice != p.device:\n                        raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device')\n\n                else:\n                    with torch.cuda.device(p.device):\n                        fused_adam_cuda.adam(p.data,\n                                             out_p,\n                                             exp_avg,\n                                             exp_avg_sq,\n                                             grad,\n                                             group['lr'],\n                                             beta1,\n                                             beta2,\n                                             group['eps'],\n                                             combined_scale,\n                                             state['step'],\n                                             self.eps_mode,\n                                             bias_correction,\n                                             group['weight_decay'])\n\n            if self._use_multi_tensor:\n                with torch.cuda.device(tensordevice):\n                    multi_tensor_applier(\n                        fused_adam_cuda.adam_mt,\n                        self._overflow_buf,\n                        tensorlists,\n                        group['lr'],\n                        beta1,\n                        beta2,\n                        group['eps'],\n                        combined_scale,\n                        state['step'],\n                        self.eps_mode,\n                        bias_correction,\n                        group['weight_decay'])\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/fused_lamb.py",
    "content": "import torch\nimport importlib\nimport math\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--deprecated_fused_lamb\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n\n        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1\" or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay(also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether set grad to None when zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,\n                 amsgrad=False, adam_w_mode=True,\n                 grad_averaging=True, set_grad_none=True,\n                 max_grad_norm=1.0):\n        if amsgrad:\n            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            
fused_lamb_cuda = importlib.import_module(\"fused_lamb_cuda\")\n            self.multi_tensor_lamb = fused_lamb_cuda.lamb\n        else:\n            raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')\n\n        g_norm_32, g_norm_16 = 0.0, 0.0\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_32], False)[0].item()\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                     
        [g_all_16], False)[0].item()\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)\n        max_grad_norm = self.defaults['max_grad_norm']\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n 
                   v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm)\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm)\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/optimizers/fused_sgd.py",
    "content": "import types\nimport torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    This version of fused SGD implements 2 fusions.\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.\n\n    :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad. \n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        model = ...\n        model.half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())\n        # wrap with FP16_Optimizer\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n        optimizer.zero_grad()\n        ...\n        optimizer.backward(loss)\n        optimizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et. al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. 
math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et. al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(self, params, lr=required, momentum=0, dampening=0,\n                 weight_decay=0, nesterov=False,\n                 wd_after_momentum=False,\n                 materialize_master_grads=True):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,\n                        weight_decay=weight_decay, nesterov=nesterov)\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = wd_after_momentum\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions')\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('nesterov', False)\n\n    def get_momentums(self, params):\n  
      momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if 'momentum_buffer' not in param_state:\n                first_run = True\n                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state['momentum_buffer'])\n        return momentums, first_run\n    \n    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. (default: None)\n            output_params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. 
(default: 1)\n        \"\"\"\n        if hasattr(self, \"_amp_stash\"):\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.')\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if grads is None:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \\\n\t                       with apex.contrib.optimizers.FP16_Optimizer \\\n\t\t\t       which provides grads.')\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0])!=list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \\\n                               with apex.contrib.optimizers.FP16_Optimizer \\\n                               which provides output_params.')\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0])!=list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        for group, grads_this_group, output_params_this_group in zip(self.param_groups, \n\t                                                             grads_group, \n                                                                     output_params_group):\n            if grads_this_group is None or output_params_this_group is None: \n                raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \\\n                                    when all parameters require grad.')\n            \n            weight_decay = group['weight_decay']\n            momentum = group['momentum']\n            dampening = group['dampening']\n            nesterov = 
group['nesterov']\n            lr = group['lr']\n\n            first_runs = [True, True]\n            \n            # output_params_this_group: original weights (either fp16 or fp32)\n            # group['params']: master weights (fp32)\n\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # fp32, fp32, fp32, No\n            fp32_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float32]\n            fp32_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float32]\n            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n            fp32_set = [fp32_grads, fp32_params, fp32_momentums]\n\n            # fp16, fp32, fp32, Yes\n            fp16_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float16]\n            fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]\n            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n            fp16_params = [p1 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]\n            fp16_set = [fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_params]\n\n            launch_sets = [fp16_set, fp32_set]\n\n            for launch_set, first_run in zip(launch_sets, first_runs):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                        dampening,\n                        lr,\n                    
    nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0/scale)\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/README.md",
    "content": "# Introduction to ASP\n\nThis serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.\n\n## Importing ASP\n```\nfrom apex.contrib.sparsity import ASP\n```\n\n## Initializing ASP\n\nApart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:\n```\nASP.prune_trained_model(model, optimizer)\n```\n\nIn the context of a typical PyTorch training loop, it might look like this:\n```\nASP.prune_trained_model(model, optimizer)\n\nx, y = DataLoader(args)\nfor epoch in range(epochs):\n    y_pred = model(x)\n    loss = loss_function(y_pred, y)\n    loss.backward()\n    optimizer.step()\n\ntorch.save(...)\n```\nThe `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. \n\n## Generate a Sparse Network\n\nThe following approach serves as a guiding example of how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e., inference mode.\n\n```\n(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.\n(2) Fine-tune the pruned model with the optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model.\n(3) (If required) Quantize the model.\n```\n\nIn code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).\n\n```\n\nmodel = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)\ncriterion = ... 
# compare ground truth with model prediction; use the same criterion as used to generate the dense trained model\noptimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model\nlr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model\n\nfrom apex.contrib.sparsity import ASP\nASP.prune_trained_model(model, optimizer) # prune the trained model\n\nx, y = DataLoader(args)\nfor epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model\n    y_pred = model(x)\n    loss = criterion(y_pred, y)\n    lr_scheduler.step()\n    loss.backward()\n    optimizer.step()\n\ntorch.save(...) # saves the pruned checkpoint with sparsity masks\n```\n\n## Non-Standard Usage\n\nIf your goal is to easily prepare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments in advanced techniques like training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:\n\n```\nASP.compute_sparse_masks()\n```\n\nA more thorough example can be found in `./test/toy_problem.py`.\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/__init__.py",
    "content": "from .sparse_masklib import create_mask\nfrom .asp import ASP\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/asp.py",
    "content": "import types\nimport torch\nfrom .sparse_masklib import create_mask\n\ntorchvision_imported=True\ntry:\n    import torchvision\nexcept ImportError:\n    print(\"[ASP][Warning] torchvision cannot be imported.\")\n    torchvision_imported=False\n\ndef eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):\n    eligible_modules_list = []\n    for name, mod in model.named_modules():\n        if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:\n            if allowed_layer_names is not None and name not in allowed_layer_names:\n                continue\n            eligible_modules_list.append((name, mod))\n    return eligible_modules_list\n\nclass ASP:\n    __model = None\n    __verbosity = 0\n    __optimizer = None\n    __sparse_parameters = []\n    __calculate_mask = None\n\n    @classmethod\n    def init_model_for_pruning(cls, model, mask_calculator=\"m4n2_1d\",\n             verbosity=3,\n             whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], \n             allowed_layer_names=None, disallowed_layer_names=[],\n             allow_recompute_mask=False, custom_layer_dict={}):\n        \"\"\"Call this method to modify your model to take advantage of sparse matrix multiplication.\n        Note that this call alone only augments the model with additional buffers needed for sparse MMA,\n        it does not enable use of sparse MMA. 
\n\n        If you are starting with a fresh model:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        if (training) ASP.init_optimizer_for_pruning(optimizer)\n        ASP.compute_sparse_masks() # sparsity is off by default; call when you want to enable it.\n\n        If you are starting from a checkpoint:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        torch.load(...)\n        if (training) ASP.init_optimizer_for_pruning(optimizer)\n\n        Arguments:\n          model                    The model\n          mask_calculator          Either callable that computes mask given a tensor OR pattern string for sparse mask lib.\n          verbosity                Integer controlling verbosity level.\n                                   0 -> Only errors.\n                                   1 -> Errors and warnings.\n                                   2 -> Errors, warnings and info.\n                                   3 -> Errors, warnings, info and debug.\n          whitelist                Module types approved for sparsity.\n          allowed_layer_names      If not None, only layer names that appear in this list are considered for sparsity.\n          disallowed_layer_names   If not [], only layer names that do not appear in this list are considered for sparsity.\n          allow_recompute_mask     If True, stores pruned values so that dense weights can be restored.\n                                   Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.\n          custom_layer_dict        Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}\n          \n          [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM. 
\n        \"\"\"\n        assert (cls.__model is None), \"ASP has been initialized already.\"\n        cls.__model = model\n        cls.__verbosity = verbosity\n\n        if isinstance(mask_calculator, str):\n            def create_mask_from_pattern(param):\n                return create_mask(param, mask_calculator).bool()\n            cls.__calculate_mask = create_mask_from_pattern\n        else:\n            cls.__calculate_mask = mask_calculator #user defined function\n\n        # function to extract variables that will be sparsified. \n        # idea is that you will add one of these functions for each module type that can be sparsified.\n        if torchvision_imported:\n            print(\"[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.\")\n            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}\n        else:\n            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}\n        if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune\n            sparse_parameter_list.update(custom_layer_dict)\n            whitelist += list(custom_layer_dict.keys())\n\n        for module_type in whitelist:\n            assert (module_type in sparse_parameter_list), \"Module %s :: Don't know how to sparsify module.\" % str(module_type)\n\n        # find all sparse modules, extract sparse parameters and decorate\n        def add_sparse_attributes(module_name, module):\n            sparse_parameters = sparse_parameter_list[type(module)]\n            for p_name, p in module.named_parameters():\n                if p_name in sparse_parameters and p.requires_grad:\n                    # check for 
NVIDIA's TC compatibility: we check along the horizontal direction\n                    if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math\n                        print(\"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                        continue\n                    if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C\n                        print(\"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                        continue\n                    \n                    if cls.__verbosity >= 3:\n                        print(\"[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n                    \n                    mask = torch.ones_like(p).bool()\n                    buffname = p_name.split(\".\")[-1] # buffer names cannot contain \".\"\n                    module.register_buffer('__%s_mma_mask' % buffname, mask)\n                    if allow_recompute_mask:\n                        pruned = torch.zeros_like(p).cpu()\n                        module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)\n                    else:\n                        pruned = None\n                    cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))\n                else:\n                    if cls.__verbosity >= 3:\n                        print(\"[ASP] Not sparsifying %s::%s of size=%s and type=%s\" % (module_name, p_name, str(p.size()), str(p.dtype)))\n\n        for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):\n            add_sparse_attributes(name, sparse_module)\n\n    
@classmethod\n    def init_optimizer_for_pruning(cls, optimizer):\n        \"\"\"Call this method to monkey patch optimizer step function so that masks can be applied to\n        gradients and weights during training.\n        You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)\n        \"\"\"\n        assert (cls.__optimizer is None), \"ASP has initialized optimizer already.\"\n        assert (cls.__calculate_mask is not None), \"Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning.\"\n\n        # store pointer to original optimizer step method\n        cls.__optimizer = optimizer\n        cls.__optimizer.__step = optimizer.step\n\n        def __step(opt_self, *args, **kwargs):\n            # prune gradients before step method\n            with torch.no_grad():\n                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                    if p.grad is not None: #thx pjudd\n                        p.grad.mul_(mask)\n            # call original optimizer step method\n            rval = opt_self.__step(*args, **kwargs)\n            # prune parameters after step method\n            with torch.no_grad():\n                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                    p.mul_(mask)\n            return rval\n        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)\n\n    @classmethod\n    def compute_sparse_masks(cls):\n        \"\"\"Call this method to enable sparsity.\n        If init(...) 
was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.\n        \"\"\"\n        with torch.no_grad():\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel(): # when recalculating masks\n                    # restore dense parameter if allow_recompute_mask is enabled\n                    assert (pruned is not None), \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    p.add_(pruned.cuda())\n\n                mask.set_(cls.__calculate_mask(p))\n\n                if pruned is not None: # stow away pruned weights to cpu\n                    pruned.set_((p * (~mask)).cpu())\n\n                p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights\n                if cls.__verbosity >= 2:\n                    print(\"[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s\" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))\n\n    @classmethod\n    def restore_pruned_weights(cls):\n        \"\"\"Call this method to disable sparsity and restore all weights.\n        This will only work if init(...) 
was called with allow_recompute=True.\n        \"\"\"\n        with torch.no_grad():\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel():\n                    assert (pruned is not None), \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    p.add_(pruned.cuda())\n                    mask.fill_(1)\n                    pruned.zero_()\n                    if cls.__verbosity >= 2:\n                        print(\"[ASP] Disabled sparsity for %s::%s (dense weights restored)\" % (module_name, p_name))\n\n    @classmethod\n    def is_sparsity_enabled(cls):\n        \"\"\"Call this method to determine if sparsity is enabled in the model.\n        The typical use case is right after checkpoint has been loaded.\n        \"\"\"\n        total,sp100,sp50 = 0,0,0\n        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n            total += 1\n            mask_sum = mask.sum()\n            mask_numel = mask.numel()\n            if mask_sum == mask_numel:\n                sp100 += 1\n            elif mask_sum*2 == mask_numel:\n                sp50 += 1\n\n        assert (total == sp100 or total == sp50), \"Inconsistent model sparsity\"\n        if total == sp100:\n            return False\n        elif total == sp50:\n            return True\n    \n    @classmethod\n    def prune_trained_model(cls, model, optimizer):\n        # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)\n        cls.init_model_for_pruning(model, mask_calculator=\"m4n2_1d\", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)\n        cls.init_optimizer_for_pruning(optimizer)\n        cls.compute_sparse_masks()\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/sparse_masklib.py",
    "content": "import sys\nimport torch\nimport numpy as np\nimport collections\nfrom itertools import permutations\n\n\n\"\"\" compute density (helper fn to compute % NNZs in a tensor) \"\"\"\ndef fill(x):\n    return float(x.nonzero().size(0))/torch.numel(x)\n\n\"\"\" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) \"\"\"\ndef reshape_1d(matrix, m):\n    # If not a nice multiple of m, fill with zeroes.\n    if matrix.shape[1] % m > 0:\n        mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0)\n        mat[:, :matrix.shape[1]] = matrix\n        shape = mat.shape\n        return mat.view(-1,m),shape\n    else:\n        return matrix.view(-1,m), matrix.shape\n\n\"\"\" return all possible m:n patterns in a 1d vector \"\"\"\nvalid_m4n2_1d_patterns = None\ndef compute_valid_1d_patterns(m,n):\n    # Early exit if patterns was already created.\n    global valid_m4n2_1d_patterns\n\n    if m==4  and n==2 and valid_m4n2_1d_patterns  is not None: return valid_m4n2_1d_patterns\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))\n    if m == 4  and n == 2: valid_m4n2_1d_patterns  = valid_patterns       \n    return valid_patterns\n\n\"\"\" m:n 1d structured best \"\"\"\ndef mn_1d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_1d_patterns(m,n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)\n    mat,shape = reshape_1d(matrix,m)\n    pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)\n    mask[:] = patterns[pmax[:]]\n    mask = mask.view(matrix.shape)\n    return mask\n\ndef m4n2_1d(mat, density):\n    return mn_1d_best(mat, 4, 2)\n\n\"\"\"\n  Below 2d-masking related code is targeted more for training (from scratch).\n  2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop\n  
phase of the training algorithm. Acceleration comes from using SpMMA instructions in\n  the Tensor Cores of the NVIDIA Ampere GPU architecture\n  (note: this code does not perform the acceleration itself; GPU kernels are required for that).\n  1d pruning speeds up the FPROP step by pruning the weight tensor in a 2:4 pattern\n  along the horizontal (logical) direction.\n  During the DGRAD step, the weight tensor is transposed. The 2d pruning functions below\n  mask the weight tensor such that its transposed version is also 2:4 sparse along the\n  horizontal (logical) direction. Thus, with 2d pruning, weight tensors are\n  2:4 sparse along both the row and column directions.\n \"\"\"\n\n\"\"\" m:n 2d structured pruning: greedy method to select mask \"\"\"\ndef mn_2d_greedy(matrix, m, n):\n    # Convert to numpy\n    mat = matrix.cpu().detach().numpy()\n    mask = np.ones(mat.shape, dtype=int)\n\n    rowCount = int(mat.shape[0]/m) * m\n    colCount = int(mat.shape[1]/m) * m\n    for rowStartIdx in range(0, rowCount, m):\n        rowEndIdx = rowStartIdx + m\n        for colStartIdx in range(0, colCount, m):\n            colEndIdx = colStartIdx + m\n            matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))\n            maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])\n            maskSub.fill(0.0)\n            matrixVecView = matrixSub.reshape(-1)\n            maskVecView   = maskSub.reshape(-1)\n            linearIdx = np.argsort(matrixVecView)\n            matrixIdx = [(int(x/m), x % m) for x in linearIdx]\n            rowCounter = collections.Counter()\n            colCounter = collections.Counter()\n            for currIdx in range(len(linearIdx) - 1, -1, -1):\n                currMatrixEntry = matrixIdx[currIdx]\n                if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):\n                    continue\n                #end if\n                maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0\n                rowCounter[currMatrixEntry[0]] += 1\n                colCounter[currMatrixEntry[1]] += 1\n\n    # mask is a numpy array here; wrap it in a tensor first, then move it to the GPU\n    return torch.tensor(mask).cuda()\n\ndef m4n2_2d_greedy(mat, density):\n    return mn_2d_greedy(mat, 4, 2)\n\n\"\"\" return all possible m:n patterns in a mxn block. \"\"\"\nvalid_m4n2_2d_patterns = None\ndef compute_valid_2d_patterns(m,n):\n    # Early exit if patterns was already created.\n    global valid_m4n2_2d_patterns\n    if valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns\n\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    patterns = list(set(permutations(patterns.tolist())))\n    patterns = patterns + patterns\n    patterns = torch.Tensor(list(set(permutations(patterns,m))))\n\n    valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)\n    valid_patterns = torch.Tensor(valid.shape[0],m,m)\n    valid_patterns[:] = patterns[valid[:]]\n\n    if m == 4  and n == 2: valid_m4n2_2d_patterns  = valid_patterns\n    return valid_patterns\n\n\"\"\" m:n 2d structured pruning: exhaustive method to select best mask \"\"\"\ndef mn_2d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_2d_patterns(m,n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1)\n    mat = reshape_2d(matrix,m,m).abs()\n    pmax = torch.argmax(torch.matmul(mat,patterns.view(patterns.shape[0],m*m).t()), dim=2)\n\n    # Copy best m:n patterns into mask.\n    mat = mat.view(mat.shape[0]*mat.shape[1],-1)\n    pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1,mat.shape[1])\n    patterns = patterns.view(patterns.shape[0],patterns.shape[1]*patterns.shape[2])\n    mat = torch.gather(patterns,0,pmax)\n    mat = reshape_2d_inv(mat.view(matrix.shape[0]//m,matrix.shape[1]//m,m,m))\n    mask.copy_(mat.type(mask.type()))\n    return mask\n\ndef m4n2_2d_best(mat, density):\n    return mn_2d_best(mat, 4, 2)\n\n\n\"\"\" returns a sparse mask 
\"\"\"\ndef create_mask(tensor, pattern=\"m4n2_1d\", density=0.5):\n    # Reshape tensor and mask.\n    shape = tensor.shape\n    ttype = tensor.type()\n    t = tensor.float().contiguous()\n\n    # 1d-tensor\n    if len(shape) == 1:\n        t = t.view(1, shape[0])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 2d-tensor (in, out)\n    elif len(shape) == 2:\n        t = t.view(shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 3d-tensor (batch, in, out)\n    elif len(shape) == 3:\n        t = t.view(shape[0]*shape[1], shape[2])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 4d-tensor (in, out, h, w)\n    elif len(shape) == 4:\n        \"\"\"\n        # transformers (bmm)\n        t = t.view(shape[0]*shape[1]*shape[2], shape[3])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n        \"\"\"\n        # convs\n        t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()      \n        return mask.view(shape).type(ttype)\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/test/checkpointing_test_part1.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    #\n    # PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, 
allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.enable_sparsity()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    torch.save({\n            'step': step,\n            'verbosity': args.verbosity,\n            'seed2': args.seed2,\n            'pattern': args.pattern,\n            'whitelist': args.whitelist,\n            'allow_recompute_mask': args.allow_recompute_mask,\n            'model_state_dict': model.state_dict(),\n            'optimizer_state_dict': optimizer.state_dict(),\n            }, args.checkpoint_path)\n\nif __name__ == '__main__':\n    class Args:\n        verbosity=3\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/test/checkpointing_test_part2.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(step, args, model_state_dict, optimizer_state_dict):\n    #\n    # PART2\n    #\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, 
allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    torch.manual_seed(args.seed2)\n    model.load_state_dict(model_state_dict)\n    optimizer.load_state_dict(optimizer_state_dict)\n\n    print(\"Model sparsity is %s\" % (\"enabled\" if ASP.sparsity_is_enabled() else \"disabled\"))\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\nif __name__ == '__main__':\n    checkpoint = torch.load(\"part1.chkp\")\n    class Args:\n        verbosity = checkpoint['verbosity']\n        seed = 4873\n        seed2 = checkpoint['seed2']\n        pattern = checkpoint['pattern']\n        whitelist = checkpoint['whitelist']\n        allow_recompute_mask = checkpoint['allow_recompute_mask']\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict'])\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/test/checkpointing_test_reference.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n#\n# Reference run for checkpointing test (part1 + part2)\n#\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    #\n    # PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(model, args.pattern, 
whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.enable_sparsity()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    #\n    # PART 2\n    #\n\n    torch.manual_seed(args.seed2)\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\nif __name__ == '__main__':\n    class Args:\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/sparsity/test/toy_problem.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n        elif i == args.num_layers-1:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])\n        else:\n            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)\n            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])\n    return torch.nn.Sequential(od)\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target-target_batch)**2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    #print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\ndef main(args):\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    # only prune linear layers, even though we also support conv1d, conv2d and conv3d\n    ASP.init_model_for_pruning(model, \"m4n2_1d\", whitelist=[torch.nn.Linear], 
allow_recompute_mask=True)\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    # recompute sparse masks\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\n    # turn off sparsity\n    print(\"SPARSE :: \",one_ll)\n    ASP.restore_pruned_weights()\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \",one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)\n\nif __name__ == '__main__':\n    class Args:\n        batch_size = 32\n        input_features = 16\n        output_features = 8\n        hidden_features = 40\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        num_dense_steps_2 = 1500\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/fmha/test_fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n# \n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n# \n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\n\nimport sys\nimport torch\nimport numpy as np\nimport unittest\nimport math\n\nimport fmhalib as mha\n\ndef py_mha(qkv, amask, b, s, h, d):\n    qkv = qkv.view(b, s, h, 3, d)\n    q = qkv[:, :, :, 0, :].permute(0,2,1,3)\n    k = qkv[:, :, :, 1, :].permute(0,2,1,3)\n    v = qkv[:, :, :, 2, :].permute(0,2,1,3)\n    p = torch.matmul(q.float(), k.permute(0,1,3,2).float())\n    p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0\n    s = torch.softmax(p_masked, -1).to(qkv.dtype)\n    ctx = torch.matmul(s, v)\n    ctx = ctx.permute(0,2,1,3).contiguous()\n\n    ctx.retain_grad()\n\n    return ctx\n\nclass TestFMHA(unittest.TestCase):\n\n    def run_test(self, s, b):\n        print(f'Test s={s} b={b}')\n\n        torch.manual_seed(1234)\n        torch.cuda.manual_seed(1234)\n        \n        dtype = torch.float16\n        device = torch.device('cuda')\n\n        h = 16 \n        d = 64\n    \n        slens = [s] * b \n        a = torch.tensor(np.array([0] + slens), dtype=torch.int32)\n        amask = torch.ones(b,h,s,s, dtype=dtype, device=device)\n        seqlens = torch.tensor(slens, dtype=torch.int32, device=device)\n        cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)\n        total = cu_seqlens[-1].item()\n    \n        qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)\n    \n        qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, 
h,d)\n    \n        qkv.requires_grad = True\n    \n        if b < 4:\n            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)\n        else:\n            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)\n        ctx = ctx.view(b,s,h,d)\n    \n        ctx_ref = py_mha(qkv, amask, b,s,h,d)\n        self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))\n    \n        labels = torch.randn_like(ctx_ref)\n        diff = ctx_ref - labels\n        l = (diff * diff).sum() / b\n        l.backward()\n    \n        dw = ctx_ref.grad.permute(0,2,1,3) \n    \n        dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()\n    \n        if b < 4:\n            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)\n        else:\n            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)\n        \n        dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)\n    \n        self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))\n\n    def test_128(self):\n        self.run_test(128, 32)\n\n    def test_256(self):\n        self.run_test(256, 32)\n\n    def test_384(self):\n        self.run_test(384, 32)\n\n    def test_512(self):\n        self.run_test(512, 32)\n        self.run_test(512, 2)\n        self.run_test(512, 3)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/layer_norm/test_fast_layer_norm.py",
    "content": "import torch\nimport unittest\nimport numpy as np\n\nimport torch.nn.functional as F\n\nfrom apex.contrib.layer_norm import FastLayerNorm\n\nimport fast_layer_norm as fln\n\n\nclass GPUTimer:\n    def __init__(self, stream):\n        self.start_ = torch.cuda.Event(enable_timing=True)\n        self.stop_ = torch.cuda.Event(enable_timing=True)\n        self.stream_ = stream\n    def start(self):\n        self.stream_.record_event(self.start_)\n    def stop(self):\n        self.stream_.record_event(self.stop_)\n    def sync(self):\n        self.stream_.synchronize()\n    def millis(self):\n        return self.start_.elapsed_time(self.stop_)\n\ndef size_in_bytes(t):\n    return torch.numel(t) * t.element_size()\ndef abs_err(x, y):\n    xf = x.float()\n    yf = y.float()\n    return ((xf-yf).abs().sum() / yf.abs().sum()).item()\n\n\n\nclass TestFastLayerNorm(unittest.TestCase):\n    \n    def setUp(self, seed=1234):\n        seed = 1234\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def test_ln_fp32(self):\n        self.run_test_layer_norm(torch.float32, atol=1e-5)\n    def test_ln_fp16(self):\n        self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3)\n\n    def run_test_layer_norm(self, dtype, atol, rtol=1e-5):\n        device = torch.device('cuda')\n        s = 512\n        b = 32\n        hidden_size = 1024\n        epsilon = 1e-5\n\n        x = torch.randn((s,b,hidden_size), dtype=dtype, device=device)  \n        beta = torch.randn(hidden_size, dtype=dtype, device=device)  \n        gamma = torch.randn(hidden_size, dtype=dtype, device=device)\n        x.requires_grad = True\n        beta.requires_grad = True\n        gamma.requires_grad = True\n\n        x2 = x.clone().detach()\n        beta2 = beta.clone().detach()\n        gamma2 = gamma.clone().detach()\n        x2.requires_grad = True\n        beta2.requires_grad = True\n        gamma2.requires_grad = True\n               \n        dummy_label = 
torch.randn_like(x)\n\n        y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon)\n\n        diff = y-dummy_label\n        l = (diff * diff).sum() / b\n        l.backward()\n\n        fln = FastLayerNorm(hidden_size).cuda()\n        fln.load_state_dict({'bias': beta2, 'weight':gamma2})\n        if dtype == torch.float16:\n            fln = fln.half()\n\n        y2 = fln(x2)\n        diff2 = (y2 - dummy_label)\n        l2 = (diff2 * diff2).sum() / b\n\n        l2.backward()\n\n        self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol))\n        self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol))\n        self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol))\n        self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol))\n    \n\n\n    def test_performance(self):\n        print()\n        runs = 1000\n        device = torch.device('cuda')\n        dtype =torch.float16\n        s = 512\n        b = 32\n        hidden_size = 1024\n        epsilon = 1e-5\n\n        x = torch.randn((s*b,hidden_size), dtype=dtype, device=device)  \n        beta = torch.randn(hidden_size, dtype=dtype, device=device)  \n        gamma = torch.randn(hidden_size, dtype=dtype, device=device)\n        dy = torch.randn_like(x)\n \n\n        stream = torch.cuda.Stream()\n        with torch.cuda.stream(stream):\n\n            timer = GPUTimer(stream)\n\n            #warmup\n            for r in range(runs):\n                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)\n           \n           \n            timer.start()\n            for r in range(runs):\n                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)\n            timer.stop()\n            timer.sync()\n\n            total_bytes_fwd = (size_in_bytes(x) \n                             + size_in_bytes(y) \n                             + size_in_bytes(gamma) \n                             + size_in_bytes(beta) \n                             + 
size_in_bytes(mu) \n                             + size_in_bytes(rsigma)\n                             )\n\n            ms_fwd = timer.millis() / runs\n            print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))\n         \n\n            timer.start()\n            for r in range(runs):\n                dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)\n            timer.stop()\n            timer.sync()\n\n            total_bytes_bwd = (size_in_bytes(x) \n                             + size_in_bytes(dx)\n                             + size_in_bytes(dy) \n                             + size_in_bytes(gamma) \n                             + size_in_bytes(dgamma)  \n                             + size_in_bytes(dbeta)  \n                             + size_in_bytes(mu) \n                             + size_in_bytes(rsigma)\n                             )\n\n\n            ms_bwd = timer.millis() / runs\n            print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nclass EncdecMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=False, \n                                             impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=False, \n                                             impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs_q = torch.randn(self.seq_length, 
self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_encdec_multihead_attn(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backprop through the layer outputs so the input .grad fields are\n        # produced by each implementation (not trivially set on the leaf tensors).\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_encdec_multihead_attn_time_mask(self) :\n        grads
 = torch.randn_like(self.tst_inputs_q)\n        time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        time_mask_bool = time_mask_byte.to(torch.bool)\n        \n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_bool,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_byte,\n                                               is_training=True)\n        \n        # Backprop through the layer outputs so the input .grad fields are\n        # produced by each implementation.\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_encdec_multihead_attn_pad_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n        pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device(\"cuda\"),
dtype=torch.uint8), 1)\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n        \n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=pad_mask_bool, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=pad_mask_byte, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nclass EncdecMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=True, \n                                             impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, \n                                             self.heads, \n                                             dropout=self.dropout_prob, \n                                             bias=False, \n                                             include_norm_add=True, \n                                             impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs_q = torch.randn(self.seq_length, 
self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                        dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_encdec_multihead_attn_norm_add(self) :\n        grads         = torch.randn_like(self.tst_inputs_q)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, \n                                               self.ref_inputs_k, \n                                               self.ref_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, \n                                               self.tst_inputs_k, \n                                               self.tst_inputs_k,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs_q,  self.tst_inputs_q,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.ref_inputs_k,  self.tst_inputs_k,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=True, \n                                           include_norm_add=False, \n                                           separate_qkv_params=True, \n                                           mask_additive=True, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=True, \n                                           include_norm_add=False, \n                                           separate_qkv_params=True, \n                                           mask_additive=True, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = 
torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    \n    def test_self_multihead_attn_additive_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=mask, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=mask, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py",
    "content": "import torch\nimport unittest\nimport torch.nn.functional as F\nfrom apex.contrib.multihead_attn import fast_mask_softmax_dropout_func\n\nclass FusedSoftmaxTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda()\n        self.mask = self.mask.half()*-10000\n        self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n        \n        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)\n\n    def test_fused_softmax(self) :\n        grads = torch.randn_like(self.tst_inputs)\n        y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)\n        y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)\n        y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length) \n        y_ref = F.softmax(y_ref, dim=-1)\n        y_ref = torch._fused_dropout(y_ref, 1.0)    \n   \n        y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)        \n        y_ref[0].backward(grads)\n        y_tst.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_self_multihead_attn.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=False, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=False, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_self_multihead_attn(self) :\n        grads         = 
torch.randn_like(self.tst_inputs)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\n    def test_self_multihead_attn_time_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        time_mask_bool= time_mask_byte.to(torch.bool)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                              
 need_weights=False, \n                                               attn_mask=time_mask_bool,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=time_mask_byte,\n                                               is_training=True)\n\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n    \n    def test_self_multihead_attn_pad_mask(self) :\n        grads         = torch.randn_like(self.tst_inputs)\n        pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device(\"cuda\"), dtype=torch.uint8), 1)\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=pad_mask_bool, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                
               self.tst_inputs,\n                                               key_padding_mask=pad_mask_byte, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py",
    "content": "import torch\n\nimport unittest\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\n\nclass SelfMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length   = 80\n        self.sequences    = 10\n        self.hidden_dim   = 1024\n        self.heads        = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=True, \n                                           impl='default')\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        \n        self.tst_layer = SelfMultiheadAttn(self.hidden_dim, \n                                           self.heads, \n                                           dropout=self.dropout_prob, \n                                           bias=False, \n                                           include_norm_add=True, \n                                           impl='fast')\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n        \n        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, \n                                      dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    def test_self_multihead_attn_norm_add(self) :\n        grads         = 
torch.randn_like(self.tst_inputs)\n\n        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, \n                                               self.ref_inputs, \n                                               self.ref_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n\n        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, \n                                               self.tst_inputs, \n                                               self.tst_inputs,\n                                               key_padding_mask=None, \n                                               need_weights=False, \n                                               attn_mask=None,\n                                               is_training=True)\n        \n        # Backpropagate through the layer outputs so the input gradients are meaningful to compare\n        ref_outputs.backward(grads)\n        tst_outputs.backward(grads)\n\n        self.assertTrue(torch.allclose(self.ref_inputs,  self.tst_inputs,  atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))\n        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/test_label_smoothing.py",
    "content": "import torch\nfrom apex.contrib import xentropy as label_smoothing\nimport unittest\n\nimport warnings\nimport random\nimport numpy as np\nimport time\n\ndef label_smoothing_raw(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    non_pad_mask = (target != padding_idx)\n    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))\n    nll_loss = nll_loss.squeeze(1)[non_pad_mask]\n    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]\n    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss\n    return loss\n\ndef label_smoothing_opt_1(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    pad_mask = (target == padding_idx)\n    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)\n    smooth_loss = logprobs.mean(dim=-1)\n    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss\n    loss.masked_fill_(pad_mask, 0)\n    return loss\n\nclass LabelSmoothingTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        # Set pytorch print precision\n        torch.set_printoptions(precision=10)\n\n    def gen_test_inputs(self, N, T, H, smoothing, padding_idx):\n        logits = torch.randn((N*T, H), dtype=torch.half, device='cuda',\n            requires_grad=True)\n        labels = torch.randint(0, H, [N*T], device='cuda')\n        for i in random.sample(range(N*T), N*T//6):\n            labels[i] = padding_idx\n        half_to_float = (logits.dtype == torch.half)\n\n        return logits, labels, half_to_float\n\n    def print_max_diff_elem(self, ref, tst):\n        ref, tst = ref.flatten(), tst.flatten()\n        diff = (ref - tst).abs().max()\n        idx = (ref - tst).abs().argmax()\n        print(\"Max atol idx: {}, diff: {:.6f}, ref: 
{:.6f}, tst: {:.6f}\".format(\n            idx, diff, ref[idx], tst[idx]))\n\n    def test_label_smoothing_function(self):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 10\n        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply\n\n        for i in range(iters):\n            logits, labels, half_to_float = self.gen_test_inputs(\n                N, T, H, smoothing, padding_idx)\n    \n            # Run original softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum()\n            loss.backward()\n            \n            ref_loss = loss.clone().detach()\n            ref_grad = logits.grad.clone().detach()\n\n            # Run optimized softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum()\n            loss.backward()\n\n            val_loss = loss.clone().detach()\n            val_grad = logits.grad.clone().detach()\n\n            # Validate\n            self.print_max_diff_elem(ref_grad, val_grad)\n            self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))\n            self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))\n\n    def test_label_smoothing_perf(self):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 1000\n        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply\n        print()\n\n        logits, labels, half_to_float = self.gen_test_inputs(\n            N, T, H, smoothing, padding_idx)\n    \n        # Run original softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in 
range(iters):\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\"Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n            time.time() - ts, iters, logits.grad.norm()))\n            \n        # Run optimized softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in range(iters):\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\"Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n            time.time() - ts, iters, logits.grad.norm()))\n\nif __name__ == '__main__':\n    unittest.main()\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/transducer/test_transducer_joint.py",
    "content": "import torch\nimport unittest\nfrom apex.contrib.transducer import TransducerJoint\nimport transducer_ref\n\nclass TransducerJointTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def gen_input(self, for_vector_kernel):\n        self.B = 4\n        T_min = 51\n        T_max = 101\n        U_min = 12\n        U_max = 25\n        if for_vector_kernel:\n            H = 512\n        else:\n            H = 509\n        dtype = torch.float16\n        device = \"cuda\"\n\n        self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)\n        self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) \n        self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max\n        self.dropout_prob = 0.5\n\n        # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by \n        # the loss function\n        for b in range(self.B):\n            self.h_grad[b, self.f_len[b]:, :, :] = 0\n            self.h_grad[b, :, self.g_len[b]:, :] = 0\n        self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)\n        \n\n    def _pack(self, x, f_len, g_len):\n        B = x.size(0)\n        list_x = []\n        for b in range(B):\n            list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        return x_packed\n\n    def _unpack(self, x, f_len, g_len):\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)\n        B = self.h_grad.size(0)\n        H = self.h_grad.size(-1)\n        for b in range(B):\n            my_batch_offset = 0 if b == 0 else batch_offset[b-1]\n            my_f_len = f_len[b]\n            my_g_len = g_len[b]\n            for t in range(my_f_len):\n                x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : \n                                                my_batch_offset + t*my_g_len + my_g_len]\n        return x_unpacked\n        \n    def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):\n        self.gen_input(for_vector_kernel=for_vector_kernel)\n        # Generate reference\n        f_ref = self.f_tst.data.clone()\n        g_ref = self.g_tst.data.clone()\n        f_ref.requires_grad = True\n        g_ref.requires_grad = True\n        \n        my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, \n                                    dropout_prob=self.dropout_prob, probe_mask=True)\n        if not pack_output:\n            h_tst = my_joint(   f=self.f_tst, \n                                
g=self.g_tst, \n                                f_len=self.f_len, \n                                g_len=self.g_len)\n            h_tst.backward(self.h_grad)\n            if dropout:\n                mask = my_joint.mask_probe[0]\n        else:\n            batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)\n            h_tst = my_joint(   f=self.f_tst, \n                                g=self.g_tst, \n                                f_len=self.f_len, \n                                g_len=self.g_len, \n                                batch_offset=batch_offset, \n                                packed_batch=batch_offset[-1])\n            h_tst.backward(self.h_grad_packed)\n            if dropout:\n                mask_packed = my_joint.mask_probe[0]\n                mask = self._unpack(mask_packed, self.f_len, self.g_len)\n\n        # reference\n        h_ref, f_grad_ref, g_grad_ref \\\n            = transducer_ref.transducer_joint_reference(f=f_ref, \n                                                        g=g_ref, \n                                                        h_grad=self.h_grad, \n                                                        f_len=self.f_len, \n                                                        g_len=self.g_len, \n                                                        pack_output=pack_output,\n                                                        relu=relu,\n                                                        dropout=dropout,\n                                                        dropout_prob=self.dropout_prob,\n                                                        mask=mask if dropout else None)\n        \n        f_grad_tst = self.f_tst.grad\n        g_grad_tst = self.g_tst.grad\n        \n        self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, 
atol=1e-4, rtol=1e-4))\n\n    def test_transducer_joint(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=False, dropout=False)\n\n    def test_transducer_joint_vec(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)\n\n    def test_transducer_joint_pack(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)\n\n    def test_transducer_joint_vec_pack(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)\n\n    def test_transducer_joint_relu(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=True, dropout=False)\n\n    def test_transducer_joint_vec_relu(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)\n\n    def test_transducer_joint_pack_relu(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)\n\n    def test_transducer_joint_vec_pack_relu(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)\n\n    def test_transducer_joint_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=True, dropout=True)\n\n    def test_transducer_joint_vec_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)\n\n    def test_transducer_joint_pack_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)\n\n    def test_transducer_joint_vec_pack_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/transducer/test_transducer_loss.py",
    "content": "import torch\nimport unittest\nfrom apex.contrib.transducer import TransducerLoss\nimport transducer_ref\n\nclass TransducerLossTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def gen_input(self, scalar_t, for_vector_kernel):\n        self.B = 5\n        T_min = 23\n        T_max = 51\n        U_min = 12\n        U_max = 25\n        V = 16 if for_vector_kernel else 14\n        self.blank_idx = V - 1\n        device = \"cuda\"\n\n        self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, \n                                    device=device)\n        self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)\n        self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) \n        self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1\n        self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)\n        # Generate reference\n        x_ref = self.x_tst.data.clone()\n        x_ref.requires_grad = True\n        loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)\n        _, _, self.grad_ref, self.loss_ref \\\n            = transducer_ref.transducer_loss_reference( x=x_ref, \n                                                        label=self.y, \n                                                        f_len=self.f_len, \n                                                        y_len=self.y_len, \n                                                        blank_idx=self.blank_idx, \n                                                        loss_grad=loss_grad)\n\n    def _pack(self, x):\n        list_x = []\n        for b in range(self.B):\n            
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)\n        return x_packed, batch_offset\n\n    def _unpack(self, x):\n        x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), \n                                    dtype=x.dtype, device=x.device)\n        for b in range(self.B):\n            my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]\n            my_f_len = self.f_len[b]\n            my_g_len = self.y_len[b] + 1\n            for t in range(my_f_len):\n                for u in range(my_g_len):\n                    x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]\n        return x_unpacked\n\n    def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):\n        self.gen_input(scalar_t, for_vector_kernel)\n        my_loss = TransducerLoss(  fuse_softmax_backward=fuse_softmax_backward, \n                                    packed_input=packed_input) \n        if not packed_input:\n            loss_tst = my_loss( x=self.x_tst,\n                                label=self.y, \n                                f_len=self.f_len, \n                                y_len=self.y_len, \n                                blank_idx=self.blank_idx)\n            loss_tst.mean().backward() \n            grad_tst = self.x_tst.grad\n        else:\n            loss_tst = my_loss( x=self.x_tst_packed,\n                                label=self.y, \n                                f_len=self.f_len, \n                                y_len=self.y_len, \n                                blank_idx=self.blank_idx,\n                                batch_offset=self.batch_offset, \n                                max_f_len=max(self.f_len))\n      
      loss_tst.mean().backward()\n            grad_tst_packed = self.x_tst_packed.grad\n            grad_tst = self._unpack(grad_tst_packed)\n        \n        return loss_tst, grad_tst\n\n    def test_transducer_loss_fp32(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float32,\n                                                        fuse_softmax_backward=False,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))\n\n    def test_transducer_loss_fp16(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=False,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=True,\n                                                        packed_input=False,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion_packed(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n            
                                            fuse_softmax_backward=True,\n                                                        packed_input=True,\n                                                        for_vector_kernel=False)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n    def test_transducer_loss_fp16_backward_fusion_packed_vec(self):\n        loss_tst, grad_tst = self.run_transducer_loss(  scalar_t=torch.float16,\n                                                        fuse_softmax_backward=True,\n                                                        packed_input=True,\n                                                        for_vector_kernel=True)\n        self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))\n        self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "KoSimCSE/apex/contrib/test/transducer/transducer_ref.py",
    "content": "import torch\nimport numpy as np\nimport pdb\n\ndef transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):\n    def log_sum_exp(a, b):\n        if (a >= b):\n            return a + torch.log(1 + torch.exp(b-a))\n        else:\n            return b + torch.log(1 + torch.exp(a-b))\n\n    def forward_alpha(x, label, f_len, y_len, blank_idx):\n        B, T, U, V = x.size()\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\n        alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\n        for b in range(B):\n            alpha[b, 0, 0] = 0\n            for t in range(1, f_len[b]):\n                alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]\n            for u in range(1, y_len[b]+1):\n                alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]\n            for t in range(1, f_len[b]):\n                for u in range(1, y_len[b]+1):\n                    curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]\n                    next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]\n                    alpha[b, t, u] = log_sum_exp(curr_, next_) \n        return alpha\n\n    def forward_beta(x, label, f_len, y_len, blank_idx):\n        B, T, U, V = x.shape\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\n        beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\n        for b in range(B):\n            beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]\n            for t in range(f_len[b]-2, -1, -1):\n                beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx] \n            for u in range(y_len[b]-1, -1, -1):\n                beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]\n            for t in range(f_len[b]-2, -1, -1):\n                for u in range(y_len[b]-1, -1, -1):\n                    curr_ = beta[b, t+1, u] + 
x[b, t, u, blank_idx] \n                    next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]\n                    beta[b, t, u] = log_sum_exp(curr_, next_) \n        return beta\n\n    def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):\n        grad = torch.zeros_like(x)\n        B, T, U, V = x.size()\n        for b in range(B):\n            common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]\n            # next\n            for u in range(y_len[b]):\n                grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u] \n                                                        + beta[b, :f_len[b], u+1] \n                                                        + x[b, :f_len[b], u, label[b, u]])\n\n            # current\n            grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \\\n                = -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1] \n                    + beta[b, 1:f_len[b], :y_len[b]+1] \n                    + x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])\n\n            grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]\n                                                         + x[b, f_len[b]-1, y_len[b], blank_idx])\n     \n        return grad\n\n    x_log = torch.nn.functional.log_softmax(x, dim=-1)\n    alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)\n    beta = forward_beta(x_log, label, f_len, y_len, blank_idx)\n    grad = backward(x_log, label, f_len, y_len, alpha, beta, \n                        loss_grad, blank_idx)\n    x_log.backward(grad)\n    loss = -beta[:, 0, 0]\n    loss = loss.to(x.dtype)\n    return alpha, beta, x.grad, loss\n\n\ndef transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, \n                                dropout_prob=0, mask=None):\n    if dropout and mask is None:\n        raise NotImplementedError(\"mask needs to be supplied to test dropout.\")\n    B, T, H = f.size()\n    U = 
g.size(1)\n    f_expand = f.unsqueeze(dim=2)\n    g_expand = g.unsqueeze(dim=1)\n    h = f_expand + g_expand\n    if relu:\n        h = torch.nn.functional.relu(h)\n    if dropout:\n        h *= mask\n        scale = 1/(1-dropout_prob)\n        h *= scale\n    h.backward(h_grad)\n\n    if pack_output == False:\n        # intentionally set the don't-care region to -1 to test whether the transducer\n        # joint writes to these regions, and to avoid NaN and inf in the comparison\n        for b in range(B):\n            h[b, f_len[b]:] = -1\n            h[b, :, g_len[b]:] = -1\n\n        return h, f.grad, g.grad \n\n    # packing\n    list_to_pack = []\n    for b in range(B):\n        list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))\n    h_packed = torch.cat(list_to_pack)\n    return h_packed, f.grad, g.grad\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/transducer/__init__.py",
    "content": "from .transducer import TransducerJoint\nfrom .transducer import TransducerLoss"
  },
  {
    "path": "KoSimCSE/apex/contrib/transducer/transducer.py",
    "content": "import torch\nimport transducer_loss_cuda\nimport transducer_joint_cuda\n\nclass TransducerJoint(torch.nn.Module):\n    \"\"\"Transducer joint\n    Detail of this operation can be found in: Sequence Transduction with Recurrent Neural \n    Networks\n\n    Arguments:\n        pack_output (bool, optional): whether to pack the output in a compact form with don't-care \n        data being removed. (default: False)\n        relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1. \n        (default: False)\n        dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1. \n        (default: False)\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. \n            (default: 1)\n        fwd_tile_size (int, optional): tile size used in forward operation. This argument will be \n        ignored if opt != 1. (default: 4) \n        dropout_prob (float, optional): dropout probability. (default: 0.0)\n        probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout\n        operation. When this argument is set to True, the mask can be accessed through \n        self.mask_probe. 
(default: False)\n    \"\"\"\n\n    def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, \n                    dropout_prob=0, probe_mask=False):\n        super(TransducerJoint, self).__init__() \n        self.pack_output = pack_output\n        self.relu = relu\n        self.dropout = dropout\n        self.dropout_prob = dropout_prob\n        self.opt = opt\n        self.fwd_tile_size = fwd_tile_size\n        self.dummy_batch_offset = torch.empty(0)\n        masked = self.relu or self.dropout\n        self.mask_probe = [] if masked and probe_mask else None\n        if masked and opt != 1:\n            raise NotImplementedError(\"ReLU and dropout fusion is only supported with opt=1\")\n\n\n    def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):\n        \"\"\"Forward operation of transducer joint\n\n        Arguments:\n            f (tensor): transcription vector from the encode block of shape (B, T, H).\n            g (tensor): prediction vector from the predict block of shape (B, U, H).\n            f_len (tensor): length of transcription vector for each batch.\n            g_len (tensor): length of prediction vector minus 1 for each batch.\n            batch_offset (tensor, optional): tensor containing the offset of each batch\n                in the results. For example, batch offset can be obtained from: \n                batch_offset = torch.cumsum(f_len*g_len, dim=0)\n                This argument is required if pack_output == True, and is ignored if \n                pack_output == False. (default: None)\n            packed_batch (int, optional): the batch size after packing. This argument is \n                ignored if pack_output == False. 
(default: 0)\n        \"\"\"\n        my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset\n        if self.pack_output and (batch_offset is None or packed_batch == 0):\n            raise Exception(\"Please specify batch_offset and packed_batch when packing is enabled\")\n        dropout = self.dropout and self.training    # only apply dropout in training mode\n        return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, \n                                            my_batch_offset, packed_batch, self.opt, \n                                            self.fwd_tile_size, self.dropout_prob, self.mask_probe)\n\n\nclass TransducerLoss(torch.nn.Module):\n    \"\"\"Transducer loss\n    Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural \n    Networks\n\n    Arguments:\n        fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with\n            softmax. (default: True)\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized \n            algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)\n        packed_input (bool, optional): whether the input is packed in a compact form with don't-care \n        data being removed. 
(default: False)\n    \"\"\"\n    def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):\n        super(TransducerLoss, self).__init__() \n        self.fuse_softmax_backward = fuse_softmax_backward\n        self.opt = opt\n        self.packed_input = packed_input\n        self.dummy_batch_offset = torch.empty(0)\n\n\n    def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, \n                debug_list=None):\n        \"\"\"Forward operation of transducer loss\n\n        Arguments:\n            x (tensor): input tensor to the loss function with a shape of (B, T, U, H).\n            label (tensor): labels for the input data.\n            f_len (tensor): lengths of the inputs in the time dimension for each batch.\n            y_len (tensor): lengths of the labels for each batch.\n            blank_idx (int): index for the null symbol.\n            batch_offset (tensor, optional): tensor containing the offset of each batch\n                in the input. For example, batch offset can be obtained from: \n                batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)\n                This argument is required if packed_input == True, and is ignored if \n                packed_input == False. (default: None)\n            max_f_len (int, optional): maximum length of the input in the time dimension.\n                For example, it can be obtained as \n                max_f_len = max(f_len)\n                This argument is required if packed_input == True, and is ignored if \n                packed_input == False. (default: None)\n            debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated \n                in the forward operation will be attached to this list for debug purpose. 
\n                (default: None)\n        \"\"\"\n        if self.packed_input:\n            if batch_offset is None or max_f_len is None:\n                raise Exception(\"Please specify batch_offset and max_f_len when packing is \\\n                                    enabled\") \n            my_batch_offset = batch_offset\n            my_max_f_len = max_f_len\n        else:\n            my_batch_offset = self.dummy_batch_offset\n            my_max_f_len = x.size(1)\n        return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, \n                                            blank_idx, self.fuse_softmax_backward, debug_list, \n                                            self.opt, self.packed_input)\n\nclass TransducerLossFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, \n                fuse_softmax_backward, debug_list, opt, packed_input):\n        if fuse_softmax_backward == False:\n            with torch.enable_grad():\n                x = torch.nn.functional.log_softmax(x, dim=-1)\n        else:\n            x = torch.nn.functional.log_softmax(x, dim=-1)\n        alpha, beta, loss = transducer_loss_cuda.forward(   x, label, f_len, y_len, batch_offset, \n                                                            max_f_len, blank_idx, opt, packed_input)\n        if debug_list == []:\n            debug_list += [alpha, beta]\n        ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)\n        ctx.blank_idx = blank_idx\n        ctx.fuse_softmax_backward = fuse_softmax_backward\n        ctx.opt = opt\n        ctx.packed_input = packed_input\n        ctx.max_f_len = max_f_len\n        return loss\n\n    @staticmethod\n    def backward(ctx, loss_grad):\n        x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors\n        x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, \n    
                                            batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, \n                                                ctx.fuse_softmax_backward, ctx.packed_input)\n        if ctx.fuse_softmax_backward == False:\n            x_grad = x.backward(x_grad)\n        return x_grad, None, None, None, None, None, None, None, None, None, None\n\nclass TransducerJointFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, \n                opt, fwd_tile_size, dropout_prob, mask_probe):\n        h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, \n                                            pack_output, relu, dropout, dropout_prob, fwd_tile_size)\n        masked = relu or dropout\n        if masked:\n            ctx.save_for_backward(h[1], f_len, g_len, batch_offset)\n            if mask_probe is not None:\n                mask_probe.append(h[1])\n        else:\n            ctx.save_for_backward(f_len, g_len, batch_offset)\n\n        ctx.pack_output = pack_output\n        ctx.masked = relu or dropout\n        ctx.max_f_len = f.size(1)\n        ctx.max_g_len = g.size(1)\n        ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1\n        return h[0]\n\n    @staticmethod\n    def backward(ctx, loss_grad):\n        if ctx.masked:\n            mask, f_len, g_len, batch_offset = ctx.saved_tensors\n            inp = [loss_grad, mask]\n        else:\n            f_len, g_len, batch_offset = ctx.saved_tensors\n            inp = [loss_grad]\n\n        f_grad, g_grad = transducer_joint_cuda.backward(    inp, f_len, g_len, batch_offset, \n                                                            ctx.max_f_len, ctx.max_g_len, \n                                                            ctx.pack_output, ctx.scale)\n\n        return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, 
None, \\\n                None, None, None\n\n\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/xentropy/__init__.py",
    "content": "try:\n    import torch\n    import xentropy_cuda\n    from .softmax_xentropy import SoftmaxCrossEntropyLoss\n    del torch\n    del xentropy_cuda\n    del softmax_xentropy\nexcept ImportError as err:\n    print(\"apex was installed without --xentropy flag, contrib.xentropy is not available\")\n"
  },
  {
    "path": "KoSimCSE/apex/contrib/xentropy/softmax_xentropy.py",
    "content": "import torch\nimport xentropy_cuda\n\nclass SoftmaxCrossEntropyLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):\n        losses, max_log_sum_exp = xentropy_cuda.forward(\n            logits, labels, smoothing, half_to_float)\n        losses.masked_fill_(labels==padding_idx, 0)\n\n        ctx.save_for_backward(logits, max_log_sum_exp, labels,\n            torch.FloatTensor([smoothing]),\n            torch.LongTensor([padding_idx]))\n\n        return losses\n\n    @staticmethod\n    def backward(ctx, grad_loss):\n        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors\n\n        if not grad_loss.is_contiguous():\n            grad_loss = grad_loss.contiguous()\n        grad_loss.masked_fill_(labels==padding_idx.item(), 0)\n        grad_logits = xentropy_cuda.backward(\n            grad_loss.contiguous(), logits, max_log_sum_exp,\n            labels, smoothing.item())\n\n        return grad_logits, None, None, None, None\n"
  },
  {
    "path": "KoSimCSE/apex/fp16_utils/README.md",
    "content": "fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing Pytorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user.  To use `FP16_Optimizer`, only two lines of one's Python model need to change.\n\n#### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)\n\n#### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple)\n\n#### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\n#### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model)\n\n\nfp16util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses.  \n\n#### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management)\n\nThe [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling.  These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically.\n"
  },
  {
    "path": "KoSimCSE/apex/fp16_utils/__init__.py",
    "content": "from .fp16util import (\n    BN_convert_float,\n    network_to_half,\n    prep_param_lists,\n    model_grads_to_master_grads,\n    master_params_to_model_params,\n    tofp16,\n    to_python_float,\n    clip_grad_norm,\n    convert_module,\n    convert_network,\n    FP16Model,\n)\n\nfrom .fp16_optimizer import FP16_Optimizer\nfrom .loss_scaler import LossScaler, DynamicLossScaler\n"
  },
  {
    "path": "KoSimCSE/apex/fp16_utils/fp16_optimizer.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\nfrom ..amp._amp_state import _amp_state, maybe_print\nfrom ..amp.scaler import LossScaler\nfrom ..multi_tensor_apply import multi_tensor_applier\nfrom .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm\n\n# TODO:  Update overflow check + downscale to use Carl's fused kernel.\nclass FP16_Optimizer(object):\n    def __init__(self, \n                 init_optimizer, \n                 static_loss_scale=1.0, \n                 dynamic_loss_scale=False,\n                 dynamic_loss_args=None,\n                 verbose=True):\n        print(\"Warning:  FP16_Optimizer is deprecated and dangerous, and will be deleted soon.  \"\n              \"If it still works, you're probably getting lucky.  \"\n              \"For mixed precision, use the documented API https://nvidia.github.io/apex/amp.html, with opt_level=O1.\")\n\n        if not torch.cuda.is_available():\n            raise SystemError(\"Cannot use fp16 without CUDA.\")\n\n        self.verbose = verbose\n\n        self.optimizer = init_optimizer\n        # init_state_dict sets up an alternative way to cast per-param state tensors.\n        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.\n        # init_state_dict = init_optimizer.state_dict()\n\n        self.fp16_groups = []\n        self.fp32_from_fp16_groups = []\n        self.fp32_from_fp32_groups = []\n        for i, param_group in enumerate(self.optimizer.param_groups):\n            self.maybe_print(\"FP16_Optimizer processing param group {}:\".format(i))\n            fp16_params_this_group = []\n            fp32_params_this_group = []\n            fp32_from_fp16_params_this_group = []\n            for i, param in enumerate(param_group['params']):\n                if 
param.requires_grad:\n                    if param.type() == 'torch.cuda.HalfTensor':\n                        self.maybe_print(\"FP16_Optimizer received torch.cuda.HalfTensor with {}\"\n                                         .format(param.size()))\n                        fp16_params_this_group.append(param)\n                        master_param = param.detach().clone().float()\n                        master_param.requires_grad = True\n                        param_group['params'][i] = master_param\n                        fp32_from_fp16_params_this_group.append(master_param)\n                        # Reset existing state dict key to the new master param.\n                        # We still need to recast per-param state tensors, if any, to FP32.\n                        if param in self.optimizer.state:\n                           self.optimizer.state[master_param] = self.optimizer.state.pop(param) \n                    elif param.type() == 'torch.cuda.FloatTensor':\n                        self.maybe_print(\"FP16_Optimizer received torch.cuda.FloatTensor with {}\"\n                                         .format(param.size()))\n                        fp32_params_this_group.append(param)\n                        param_group['params'][i] = param\n                    else:\n                        raise TypeError(\"Wrapped parameters must be either \"\n                                        \"torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
\"  \n                                        \"Received {}\".format(param.type()))\n            \n            self.fp16_groups.append(fp16_params_this_group)\n            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)\n            self.fp32_from_fp32_groups.append(fp32_params_this_group)\n\n        self.all_fp16_params = []\n        for group in self.fp16_groups:\n            self.all_fp16_params += group\n\n        self.all_fp32_from_fp16_params = []\n        for group in self.fp32_from_fp16_groups:\n            self.all_fp32_from_fp16_params += group\n\n        self.all_fp32_from_fp32_params = []\n        for group in self.fp32_from_fp32_groups:\n            self.all_fp32_from_fp32_params += group\n\n        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors\n        self.optimizer.load_state_dict(self.optimizer.state_dict())\n        # alternative way to cast per-param state tensors:\n        # self.optimizer.load_state_dict(init_state_dict)\n\n        if dynamic_loss_scale:\n            self.dynamic_loss_scale = True\n            if dynamic_loss_args is not None:\n                self.loss_scaler = LossScaler(\"dynamic\", **dynamic_loss_args)\n            else:\n                self.loss_scaler = LossScaler(\"dynamic\")\n        else:\n            self.dynamic_loss_scale = False\n            self.loss_scaler = LossScaler(static_loss_scale)\n\n        self.overflow = False\n        self.first_closure_call_this_step = True\n\n        self.clip_grad_norm = clip_grad_norm\n\n        # TODO:  Centralize exposure and import error checking for the C backend.\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_scale = amp_C.multi_tensor_scale\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0]);\n\n    # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact\n    # of having to support FP16_Optimizer separately, for the 
time being.\n    def maybe_print(self, msg):\n        if self.verbose:\n            print(msg)\n            \n    def __getstate__(self):\n        raise RuntimeError(\"FP16_Optimizer should be serialized using state_dict().\")\n\n    def __setstate__(self, state):\n        raise RuntimeError(\"FP16_Optimizer should be deserialized using load_state_dict().\")\n\n    def zero_grad(self, set_grads_to_None=False):\n        \"\"\"\n        Zero fp32 and fp16 parameter grads.\n        \"\"\"\n        # In principle, only the .grad attributes of the model params need to be zeroed,\n        # because gradients are copied into the FP32 master params.  However, we zero\n        # all gradients owned by the optimizer, just to be safe:\n        for group in self.optimizer.param_groups:\n             for p in group['params']:\n                 if set_grads_to_None:\n                     p.grad = None\n                 else:\n                     if p.grad is not None:\n                         p.grad.detach_()\n                         p.grad.zero_()\n\n        # Zero fp16 gradients owned by the model:\n        for fp16_group in self.fp16_groups:\n            for param in fp16_group:\n                if set_grads_to_None:\n                    param.grad = None\n                else:\n                    if param.grad is not None:\n                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()\n                        param.grad.zero_()\n\n    # Should not be used anymore.\n    # def _check_overflow(self):\n    #     params = []\n    #     for group in self.fp16_groups:\n    #         for param in group:\n    #             params.append(param)\n    #     for group in self.fp32_from_fp32_groups:\n    #         for param in group:\n    #             params.append(param)\n    #     self.overflow = self.loss_scaler.has_overflow(params)\n\n    # def _update_scale(self, has_overflow=False):\n    #     self.loss_scaler.update_scale(has_overflow)\n\n    def 
_master_params_to_model_params(self):\n        if multi_tensor_applier.available:\n            if len(self.all_fp16_params) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_scale,\n                    self._dummy_overflow_buf,\n                    [self.all_fp32_from_fp16_params, self.all_fp16_params],\n                    1.0)\n        else:\n            for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):\n                master_params_to_model_params(fp16_group, fp32_from_fp16_group)\n\n    # To consider:  Integrate distributed with this wrapper by registering a hook on each variable\n    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.\n    # def _model_grads_to_master_grads(self):\n    #     for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):\n    #         model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)\n\n    # def _downscale_master(self):\n    #     if self.loss_scale != 1.0:\n    #         for group in self.optimizer.param_groups:\n    #             for param in group['params']:\n    #                 if param.grad is not None:\n    #                     param.grad.data.mul_(1./self.loss_scale)\n\n    def clip_master_grads(self, max_norm, norm_type=2):\n        \"\"\"\n        Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.\n\n        Args:\n            max_norm (float or int): max norm of the gradients\n            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n                infinity norm.\n\n        Returns:\n            Total norm of the current fp32 gradients (viewed as a single vector).\n\n        .. 
warning::\n            Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).\n        \"\"\"\n        if not self.overflow:\n            fp32_params = []\n            for param_group in self.optimizer.param_groups:\n                for param in param_group['params']:\n                    fp32_params.append(param)\n            return self.clip_grad_norm(fp32_params, max_norm, norm_type)\n        else:\n            return -1\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict['loss_scaler'] = self.loss_scaler\n        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale\n        state_dict['overflow'] = self.overflow\n        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step\n        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()\n        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict(). 
\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, \n        whose parameters in turn came from ``model``, it is expected that the user \n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n\n        Example::\n\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # I think it should actually be ok to reload the optimizer before the model.\n        self.loss_scaler = state_dict['loss_scaler']\n        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']\n        self.overflow = state_dict['overflow']\n        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']\n        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.  \n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  There are two options.  \n        # 1:  Refresh the master params from the model's fp16 params.  \n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 2.\n        # \n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device \n        # of their associated parameters, because it's possible those buffers might not exist yet in \n        # the current optimizer instance.  
In our case, as long as the current FP16_Optimizer has been \n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):\n            for current, saved in zip(current_group, saved_group):\n                current.data.copy_(saved.data)\n\n    def step(self, closure=None): # could add clip option.\n        \"\"\"\n        If no closure is supplied, :attr:`step` should be called after \n        ``fp16_optimizer_obj.backward(loss)``.\n        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to\n        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params\n        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run\n        another forward pass using their model.\n\n        If a closure is supplied, :attr:`step` may be called without a prior call to \n        :attr:`backward(loss)`.\n        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.\n        However, the user should take care that any ``loss.backward()`` call within the closure\n        has been replaced by ``fp16_optimizer_obj.backward(loss)``.\n\n        Args:\n           closure (optional):  Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor.  
closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.\n\n        Example with closure::\n\n            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an \n            # existing pytorch optimizer.\n            for input, target in dataset:\n                def closure():\n                    optimizer.zero_grad()\n                    output = model(input)\n                    loss = loss_fn(output, target)\n                    # loss.backward() becomes:\n                    optimizer.backward(loss)\n                    return loss\n                optimizer.step(closure)\n\n        .. warning::\n            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.\n\n        .. _`ordinary Pytorch optimizer use`:\n            http://pytorch.org/docs/master/optim.html#optimizer-step-closure\n        \"\"\"\n\n        scale = self.loss_scaler.loss_scale()\n        # To consider:  Should this be in step(), or update_master_grads?  It works either way,\n        # but I should make it consistent with the Amp control flow, which updates the scale\n        # during backward context manager exit.\n        # self._update_scale(self.overflow)\n\n        if self.overflow:\n            # Using _amp_state.maybe_print instead of self.print here is intentional.\n            maybe_print(\"Gradient overflow.  
Skipping step, reducing \" +\n                \"loss scale to {}\".format(self.loss_scaler.loss_scale()))\n            return\n        \n        if closure is not None:\n            retval = self._step_with_closure(closure)\n        else:\n            # torch.cuda.nvtx.range_push(\"pytorch optimizer step\")\n            retval = self.optimizer.step()\n            # torch.cuda.nvtx.range_pop()\n\n        self._master_params_to_model_params()\n\n        return retval\n\n    def _step_with_closure(self, closure):\n        def wrapped_closure():\n            # helpful for debugging\n            # print(\"Calling wrapped_closure, first_closure_call_this_step = {}\"\n            #       .format(self.first_closure_call_this_step))\n            if self.first_closure_call_this_step:\n                # We expect that the fp16 params are initially fresh on entering self.step(),\n                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()\n                # is called within self.optimizer.step().\n                self.first_closure_call_this_step = False\n            else:\n                # If self.optimizer.step() internally calls wrapped_closure more than once,\n                # it may update the fp32 params after each call.  However, self.optimizer \n                # doesn't know about the fp16 params at all.  If the fp32 params get updated,\n                # we can't rely on self.optimizer to refresh the fp16 params.  
We need\n                # to handle that manually:\n                self._master_params_to_model_params()\n            # Our API expects the user to give us ownership of the backward() call by\n            # replacing all calls to loss.backward() with optimizer.backward(loss).\n            # This requirement holds whether or not the call to backward() is made within a closure.\n            # If the user is properly calling optimizer.backward(loss) within \"closure,\" \n            # calling closure() here will give the fp32 master params fresh gradients\n            # for the optimizer to play with, so all wrapped_closure needs to do is call \n            # closure() and return the loss.\n            temp_loss = closure() \n            while(self.overflow):\n                scale = self.loss_scaler.loss_scale()\n                # self._update_scale(self.overflow) # now done at the end of backward\n                print(\"OVERFLOW within closure! Skipping step, reducing loss scale to {}\".format(\n                      self.loss_scaler.loss_scale()))\n                temp_loss = closure()\n            return temp_loss\n\n        retval = self.optimizer.step(wrapped_closure)\n\n        self.first_closure_call_this_step = True\n\n        return retval\n\n    def backward(self, loss, update_master_grads=True, retain_graph=False):\n        \"\"\" \n        :attr:`backward` performs the following conceptual steps:\n\n        1. fp32_loss = loss.float() (see first Note below)\n        2. scaled_loss = fp32_loss*loss_scale\n        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).\n        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.\n        5. 
Finally, master grads are divided by loss_scale.\n\n        In this way, after :attr:`backward`, the master params have fresh gradients,\n        and :attr:`step` may be called.\n\n        .. note::\n            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.\n            This provides some additional safety against overflow if the user has supplied an \n            fp16 loss value.  \n            However, for maximum overflow safety, the user should\n            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to \n            :attr:`backward`.\n\n        .. warning::\n            The gradients found in a model's leaves after the call to \n            :attr:`backward` should not be regarded as valid in general, \n            because it's possible \n            they have been scaled (and in the case of dynamic loss scaling, \n            the scale factor may change over time).  \n            If the user wants to inspect gradients after a call to :attr:`backward`,  \n            only the master gradients should be regarded as valid.  These can be retrieved via\n            :attr:`inspect_master_grad_data()`.\n\n        Args:\n            loss:  The loss output by the user's model.  loss may be either float or half (but see first Note above).\n            update_master_grads (bool, optional, default=True):  Option to copy fp16 grads to fp32 grads on this call.  By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration.  If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.\n            retain_graph (bool, optional, default=False):  Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``.  
If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).\n\n        Example::\n\n            # Ordinary operation:\n            optimizer.backward(loss)\n\n            # Naive operation with multiple losses (technically valid, but less efficient):\n            # fp32 grads will be correct after the second call,  but \n            # the first call incurs an unnecessary fp16->fp32 grad copy.\n            optimizer.backward(loss1)\n            optimizer.backward(loss2)\n\n            # More efficient way to handle multiple losses:\n            # The fp16->fp32 grad copy is delayed until fp16 grads from all \n            # losses have been accumulated.\n            optimizer.backward(loss1, update_master_grads=False)\n            optimizer.backward(loss2, update_master_grads=False)\n            optimizer.update_master_grads()\n        \"\"\" \n        # To consider:  try multiple backward passes using retain_grad=True to find \n        # a loss scale that works.  After you find a loss scale that works, do a final dummy\n        # backward pass with retain_graph=False to tear down the graph.  Doing this would avoid \n        # discarding the iteration,  but probably wouldn't improve overall efficiency.  \n        scaled_loss = loss.float()*self.loss_scaler.loss_scale()\n        scaled_loss.backward(retain_graph=retain_graph)\n        if update_master_grads:\n            self.update_master_grads()\n\n    def update_master_grads(self):\n        # torch.cuda.nvtx.range_push(\"update_master_grads\")\n        \"\"\"\n        Copy the ``.grad`` attribute from stored references to fp16 parameters to \n        the ``.grad`` attribute of the fp32 master parameters that are directly \n        updated by the optimizer.  
:attr:`update_master_grads` only needs to be called if\n        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.\n        \"\"\"\n        # if self.dynamic_loss_scale:\n        #     self._check_overflow()\n        #     if self.overflow: return\n        # self._model_grads_to_master_grads()\n        # self._downscale_master()\n        # Use the one-shot multi-tensor apply kernel\n        self.loss_scaler.clear_overflow_state()\n        if len(self.all_fp16_params) > 0:\n            # print(\"Model grads before\")\n            # print([param.grad.data for param in self.all_fp16_params])\n            # I'm ONLY writing this as an incremental way to make some tests pass until\n            # I can refactor the tests as well.\n            # FP16_Optimizer should not be used by anyone.\n            model_grads = []\n            master_grads = []\n            for model_param, master_param in zip(self.all_fp16_params,\n                                                 self.all_fp32_from_fp16_params):\n                if model_param.grad is not None:\n                    model_grads.append(model_param.grad)\n                    if master_param.grad is None:\n                        master_param.grad = torch.empty_like(master_param)\n                    master_grads.append(master_param.grad)\n            self.loss_scaler.unscale(\n                model_grads,\n                master_grads,\n                self.loss_scaler.loss_scale())\n            # print(\"Master grads after\")\n            # print([param.grad.data for param in self.all_fp32_from_fp16_params])\n        if len(self.all_fp32_from_fp32_params) > 0:\n            model_grads = []\n            master_grads = []\n            for model_param, master_param in zip(self.all_fp32_from_fp32_params,\n                                                 self.all_fp32_from_fp32_params):\n                if model_param.grad is not None:\n                    model_grads.append(model_param.grad)\n  
                  master_grads.append(master_param.grad)\n            # print(\"Model grads before\")\n            # print([param.grad.data for param in self.all_fp32_from_fp32_params])\n            self.loss_scaler.unscale(\n                model_grads,\n                master_grads,\n                self.loss_scaler.loss_scale())\n            # print(\"Master grads after\")\n            # print([param.grad.data for param in self.all_fp32_from_fp32_params])\n        # quit()\n        self.overflow = self.loss_scaler.update_scale()\n        # torch.cuda.nvtx.range_pop()\n\n\n    def inspect_master_grad_data(self):\n        \"\"\"\n        When running with :class:`FP16_Optimizer`, \n        ``.grad`` attributes of a model's fp16 leaves should not be\n        regarded as truthful, because they might be scaled.  \n        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,\n        the fp32 master params' ``.grad``\n        attributes will contain valid gradients properly divided by the loss scale.  However, \n        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be \n        nonintuitive.  :attr:`inspect_master_grad_data`\n        allows those gradients to be viewed with shapes corresponding to their associated model leaves.\n\n        Returns:\n            List of lists (one list for each parameter group).  The list for each parameter group\n            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.                 \n        \"\"\"\n        if self.overflow:\n            print(\"Warning:  calling FP16_Optimizer.inspect_master_grad_data while in an overflow state.  \"\n                  \"Gradients are currently invalid (may be inf, nan, or stale).  
Returning None.\")\n            return None\n        else:\n            # The optimizer owns only references to master params.\n            master_grads_data = []\n            for param_group in self.optimizer.param_groups:\n                master_grads_this_group = []\n                for param in param_group['params']:\n                    if param.grad is not None:\n                        master_grads_this_group.append(param.grad.data)\n                    else:\n                        master_grads_this_group.append(None)\n                master_grads_data.append(master_grads_this_group)\n            return master_grads_data\n\n\n    # Promote loss scale so it can be retrieved or set via \"fp16_optimizer_instance.loss_scale\"\n    def _get_loss_scale(self):\n        return self.loss_scaler.loss_scale()\n\n    def _set_loss_scale(self, value):\n        self.loss_scaler._loss_scale = value\n\n    loss_scale = property(_get_loss_scale, _set_loss_scale)\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return self.optimizer.param_groups\n\n    def _set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n"
  },
  {
    "path": "KoSimCSE/apex/fp16_utils/fp16util.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\n\nclass tofp16(nn.Module):\n    \"\"\"\n    Utility module that implements::\n\n        def forward(self, input):\n            return input.half()\n    \"\"\"\n\n    def __init__(self):\n        super(tofp16, self).__init__()\n\n    def forward(self, input):\n        return input.half()\n\n\ndef BN_convert_float(module):\n    \"\"\"\n    Utility function for network_to_half().\n\n    Retained for legacy purposes.\n    \"\"\"\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:\n        module.float()\n    for child in module.children():\n        BN_convert_float(child)\n    return module\n\n\ndef network_to_half(network):\n    \"\"\"\n    Convert model to half precision in a batchnorm-safe way.\n\n    Retained for legacy purposes. It is recommended to use FP16Model.\n    \"\"\"\n    return nn.Sequential(tofp16(), BN_convert_float(network.half()))\n\n\ndef convert_module(module, dtype):\n    \"\"\"\n    Converts a module's immediate parameters and buffers to dtype.\n    \"\"\"\n    for param in module.parameters(recurse=False):\n        if param is not None:\n            if param.data.dtype.is_floating_point:\n                param.data = param.data.to(dtype=dtype)\n            if param._grad is not None and param._grad.data.dtype.is_floating_point:\n                param._grad.data = param._grad.data.to(dtype=dtype)\n\n    for buf in module.buffers(recurse=False):\n        if buf is not None and buf.data.dtype.is_floating_point:\n            buf.data = buf.data.to(dtype=dtype)\n\n\ndef convert_network(network, dtype):\n    \"\"\"\n    Converts a network's parameters and buffers to dtype.\n    \"\"\"\n    for module in network.modules():\n        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:\n            continue\n        
convert_module(module, dtype)\n        if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):\n            module.flatten_parameters()\n    return network\n\n\nclass FP16Model(nn.Module):\n    \"\"\"\n    Convert model to half precision in a batchnorm-safe way.\n    \"\"\"\n\n    def __init__(self, network):\n        super(FP16Model, self).__init__()\n        self.network = convert_network(network, dtype=torch.half)\n\n    def forward(self, *inputs):\n        inputs = tuple(t.half() for t in inputs)\n        return self.network(*inputs)\n\n\ndef backwards_debug_hook(grad):\n    raise RuntimeError(\"master_params received a gradient in the backward pass!\")\n\ndef prep_param_lists(model, flat_master=False):\n    \"\"\"\n    Creates a list of FP32 master parameters for a given model, as in\n    `Training Neural Networks with Mixed Precision:  Real Examples`_.\n\n    Args:\n        model (torch.nn.Module): Existing Pytorch model\n        flat_master (bool, optional, default=False):  Flatten the master parameters into a single tensor, as a performance optimization.\n    Returns:\n        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`.  ``master_params`` is a list of FP32 master parameters.  If ``flat_master=True``, ``master_params`` will be a list with one element.\n\n    Example::\n\n        model_params, master_params = prep_param_lists(model)\n\n    .. warning::\n        Currently, if ``flat_master=True``, all the model's parameters must be the same type.  If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.\n\n    .. 
_`Training Neural Networks with Mixed Precision:  Real Examples`:\n        http://on-demand.gputechconf.com/gtc/2018/video/S81012/\n    \"\"\"\n    model_params = [param for param in model.parameters() if param.requires_grad]\n\n    if flat_master:\n        # Give the user some more useful error messages\n        try:\n            # flatten_dense_tensors returns a contiguous flat array.\n            # http://pytorch.org/docs/master/_modules/torch/_utils.html\n            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()\n        except:\n            print(\"Error in prep_param_lists:  model may contain a mixture of parameters \"\n                      \"of different types.  Use flat_master=False, or use FP16_Optimizer.\")\n            raise\n        master_params = torch.nn.Parameter(master_params)\n        master_params.requires_grad = True\n        # master_params.register_hook(backwards_debug_hook)\n        if master_params.grad is None:\n            master_params.grad = master_params.new(*master_params.size())\n        return model_params, [master_params]\n    else:\n        master_params = [param.clone().float().detach() for param in model_params]\n        for param in master_params:\n            param.requires_grad = True\n        return model_params, master_params\n\n\ndef model_grads_to_master_grads(model_params, master_params, flat_master=False):\n    \"\"\"\n    Copy model gradients to master gradients.  \n\n    Args:\n        model_params:  List of model parameters created by :func:`prep_param_lists`.\n        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  
If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.\n    \"\"\"\n    if flat_master:\n        # The flattening may incur one more deep copy than is necessary.\n        master_params[0].grad.data.copy_(\n            _flatten_dense_tensors([p.grad.data for p in model_params]))\n    else:\n        for model, master in zip(model_params, master_params):\n            if model.grad is not None:\n                if master.grad is None:\n                    master.grad = Variable(master.data.new(*master.data.size()))\n                master.grad.data.copy_(model.grad.data)\n            else:\n                master.grad = None\n\n\ndef master_params_to_model_params(model_params, master_params, flat_master=False):\n    \"\"\"\n    Copy master parameters to model parameters.\n\n    Args:\n        model_params:  List of model parameters created by :func:`prep_param_lists`.\n        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.\n    \"\"\"\n    if flat_master:\n        for model, master in zip(model_params, \n                                 _unflatten_dense_tensors(master_params[0].data, model_params)):\n            model.data.copy_(master)\n    else:\n        for model, master in zip(model_params, master_params):\n            model.data.copy_(master.data)\n\n# Backward compatibility fixes\n\ndef to_python_float(t):\n    if hasattr(t, 'item'):\n        return t.item()\n    else:\n        return t[0]\n\nTORCH_MAJOR = int(torch.__version__.split('.')[0])\nTORCH_MINOR = int(torch.__version__.split('.')[1])\nif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:\n    clip_grad_norm = torch.nn.utils.clip_grad_norm\nelse:\n    clip_grad_norm = torch.nn.utils.clip_grad_norm_\n"
  },
  {
    "path": "KoSimCSE/apex/fp16_utils/loss_scaler.py",
    "content": "import torch\n\n# item() is a recent addition, so this helps with backward compatibility.\ndef to_python_float(t):\n    if hasattr(t, 'item'):\n        return t.item()\n    else:\n        return t[0]\n\nclass LossScaler:\n    \"\"\"\n    Class that manages a static loss scale.  This class is intended to interact with\n    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.\n\n    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to \n    :class:`FP16_Optimizer`'s constructor.\n\n    Args:\n        scale (float, optional, default=1.0):  The loss scale.\n    \"\"\"\n\n    def __init__(self, scale=1):\n        self.cur_scale = scale\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow(self, params):\n        return False\n\n    # `x` is a torch.Tensor\n    def _has_inf_or_nan(x):\n        return False\n\n    def update_scale(self, overflow):\n        pass\n\n    @property\n    def loss_scale(self):\n        return self.cur_scale\n\n    def scale_gradient(self, module, grad_in, grad_out):\n        return tuple(self.loss_scale * g for g in grad_in)\n\n    def backward(self, loss, retain_graph=False):\n        scaled_loss = loss*self.loss_scale\n        scaled_loss.backward(retain_graph=retain_graph)\n\nclass DynamicLossScaler:\n    \"\"\"\n    Class that manages dynamic loss scaling.  It is recommended to use :class:`DynamicLossScaler`\n    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of \n    :class:`FP16_Optimizer`.  However, it's important to understand how :class:`DynamicLossScaler`\n    operates, because the default options can be changed using the\n    the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.\n\n    Loss scaling is designed to combat the problem of underflowing gradients encountered at long\n    times when training fp16 networks.  Dynamic loss scaling begins by attempting a very high loss\n    scale.  
Ironically, this may result in OVERflowing gradients.  If overflowing gradients are\n    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has \n    occurred.\n    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,\n    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.  \n    If a certain number of iterations occur without overflowing gradients detected,\n    :class:`DynamicLossScaler` increases the loss scale once more.\n    In this way :class:`DynamicLossScaler` attempts to \"ride the edge\" of \n    always using the highest loss scale possible without incurring overflow.\n\n    Args:\n        init_scale (float, optional, default=2**32):  Initial loss scale attempted by :class:`DynamicLossScaler.`\n        scale_factor (float, optional, default=2.0):  Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``.  If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 
\n        scale_window (int, optional, default=1000):  Number of consecutive iterations without an overflow to wait before increasing the loss scale.\n    \"\"\"\n\n    def __init__(self,\n                 init_scale=2**32,\n                 scale_factor=2.,\n                 scale_window=1000):\n        self.cur_scale = init_scale\n        self.cur_iter = 0\n        self.last_overflow_iter = -1\n        self.scale_factor = scale_factor\n        self.scale_window = scale_window\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow(self, params):\n        for p in params:\n            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):\n                return True\n\n        return False\n\n    # `x` is a torch.Tensor\n    def _has_inf_or_nan(x):\n        try:\n            # if x is half, the .float() incurs an additional deep copy, but it's necessary if \n            # Pytorch's .sum() creates a one-element tensor of the same type as x \n            # (which is true for some recent version of pytorch).\n            cpu_sum = float(x.float().sum())\n            # More efficient version that can be used if .sum() returns a Python scalar\n            # cpu_sum = float(x.sum())\n        except RuntimeError as instance:\n            # We want to check if inst is actually an overflow exception.\n            # RuntimeError could come from a different error.\n            # If so, we still want the exception to propagate.\n            if \"value cannot be converted\" not in instance.args[0]:\n                raise\n            return True\n        else:\n            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n                return True\n            return False\n\n    # `overflow` is boolean indicating whether the gradient overflowed\n    def update_scale(self, overflow):\n        if overflow:\n            # self.cur_scale /= self.scale_factor\n            self.cur_scale = 
max(self.cur_scale/self.scale_factor, 1)\n            self.last_overflow_iter = self.cur_iter\n        else:\n            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:\n                self.cur_scale *= self.scale_factor\n        self.cur_iter += 1\n\n    @property\n    def loss_scale(self):\n        return self.cur_scale\n\n    def scale_gradient(self, module, grad_in, grad_out):\n        return tuple(self.loss_scale * g for g in grad_in)\n\n    def backward(self, loss, retain_graph=False):\n        scaled_loss = loss*self.loss_scale\n        scaled_loss.backward(retain_graph=retain_graph)\n        \n##############################################################        \n# Example usage below here -- assuming it's in a separate file\n##############################################################\n\"\"\"\nTO-DO separate out into an example.\nif __name__ == \"__main__\":\n    import torch\n    from torch.autograd import Variable\n    from dynamic_loss_scaler import DynamicLossScaler\n\n    # N is batch size; D_in is input dimension;\n    # H is hidden dimension; D_out is output dimension.\n    N, D_in, H, D_out = 64, 1000, 100, 10\n\n    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n    x = Variable(torch.randn(N, D_in), requires_grad=False)\n    y = Variable(torch.randn(N, D_out), requires_grad=False)\n\n    w1 = Variable(torch.randn(D_in, H), requires_grad=True)\n    w2 = Variable(torch.randn(H, D_out), requires_grad=True)\n    parameters = [w1, w2]\n\n    learning_rate = 1e-6\n    optimizer = torch.optim.SGD(parameters, lr=learning_rate)\n    loss_scaler = DynamicLossScaler()\n\n    for t in range(500):\n        y_pred = x.mm(w1).clamp(min=0).mm(w2)\n        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale\n        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))\n        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))\n        print('Iter {} unscaled loss: 
{}'.format(t, loss.data[0] / loss_scaler.loss_scale))\n\n        # Run backprop\n        optimizer.zero_grad()\n        loss.backward()\n        \n        # Check for overflow\n        has_overflow = DynamicLossScaler.has_overflow(parameters)\n        \n        # If no overflow, unscale grad and update as usual\n        if not has_overflow:\n            for param in parameters:\n                param.grad.data.mul_(1. / loss_scaler.loss_scale)\n            optimizer.step()\n        # Otherwise, don't do anything -- ie, skip iteration\n        else:\n            print('OVERFLOW!')\n\n        # Update loss scale for next iteration\n        loss_scaler.update_scale(has_overflow)\n\n\"\"\"\n"
  },
  {
    "path": "KoSimCSE/apex/mlp/__init__.py",
    "content": "from .mlp import *\n"
  },
  {
    "path": "KoSimCSE/apex/mlp/mlp.py",
    "content": "from copy import copy\nimport math\nimport torch\nfrom torch import nn\nimport mlp_cuda\nfrom .. import amp\n\nclass MlpFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, bias, activation, *args):\n        output = mlp_cuda.forward(bias, activation, args)\n        ctx.save_for_backward(*args)\n        ctx.outputs = output\n        ctx.bias = bias\n        ctx.activation = activation\n        return output[0]\n\n    @staticmethod\n    def backward(ctx, grad_o):\n        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)\n        del ctx.outputs\n        return (None, None, *grads)\n\nmlp_function = amp.half_function(MlpFunction.apply)\n\nclass MLP(torch.nn.Module):\n    \"\"\"Launch MLP in C++\n\n    Args:\n        mlp_sizes (list of int): MLP sizes. Example: [1024,1024,1024] will create 2 MLP layers with shape 1024x1024\n        bias (bool): Default True:\n        relu (bool): Default True\n    \"\"\"\n    def __init__(self, mlp_sizes, bias=True, activation='relu'):\n        super(MLP, self).__init__()\n        self.num_layers = len(mlp_sizes) - 1\n        self.mlp_sizes = copy(mlp_sizes)\n        self.bias = 1 if bias else 0\n\n        if activation is 'none':\n            self.activation = 0\n        elif activation is 'relu':\n            self.activation = 1\n        elif activation is 'sigmoid':\n            self.activation = 2\n        else:\n            raise TypeError(\"activation must be relu or none.\")\n\n        self.weights = []\n        self.biases = []\n        for i in range(self.num_layers):\n            w = torch.nn.Parameter(torch.empty(mlp_sizes[i+1], mlp_sizes[i]))\n            self.weights.append(w)\n            name = 'weight_{}'.format(i)\n            setattr(self, name, w)\n            if self.bias:\n                b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))\n                self.biases.append(b)\n                name = 'bias_{}'.format(i)\n            
    setattr(self, name, b)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for weight in self.weights:\n            dimsum = weight.size(0) + weight.size(1)\n            std = math.sqrt(2. / float(dimsum))\n            nn.init.normal_(weight, 0., std)\n        if self.bias:\n            for bias in self.biases:\n                std = math.sqrt(1. / float(bias.size(0)))\n                nn.init.normal_(bias, 0., std)\n\n    def forward(self, input):\n        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)\n\n    def extra_repr(self):\n        s = F\"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}\"\n        return s\n"
  },
  {
    "path": "KoSimCSE/apex/multi_tensor_apply/__init__.py",
    "content": "from .multi_tensor_apply import MultiTensorApply\n\nmulti_tensor_applier = MultiTensorApply(2048*32)\n\n"
  },
  {
    "path": "KoSimCSE/apex/multi_tensor_apply/multi_tensor_apply.py",
    "content": "import torch\n\nclass MultiTensorApply(object):\n    available = False\n    warned = False\n\n    def __init__(self, chunk_size):\n        try:\n            import amp_C\n            MultiTensorApply.available = True\n            self.chunk_size = chunk_size\n        except ImportError as err:\n            MultiTensorApply.available = False\n            MultiTensorApply.import_err = err\n\n    def check_avail(self):\n        if MultiTensorApply.available == False:\n            raise RuntimeError(\n                \"Attempted to call MultiTensorApply method, but MultiTensorApply \"\n                \"is not available, possibly because Apex was installed without \"\n                \"--cpp_ext --cuda_ext.  Original import error message:\",\n                MultiTensorApply.import_err)\n\n    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):\n        self.check_avail()\n\n        return op(self.chunk_size,\n                  noop_flag_buffer,\n                  tensor_lists,\n                  *args)\n"
  },
  {
    "path": "KoSimCSE/apex/normalization/__init__.py",
    "content": "from .fused_layer_norm import FusedLayerNorm\n"
  },
  {
    "path": "KoSimCSE/apex/normalization/fused_layer_norm.py",
    "content": "import math\nimport torch\nimport numbers\nfrom torch.nn.parameter import Parameter\nfrom torch.nn import init\nfrom torch.nn import functional as F\nimport importlib\n\nglobal fused_layer_norm_cuda\nfused_layer_norm_cuda = None\n\nclass FusedLayerNormAffineFunction(torch.autograd.Function):\n\n  @staticmethod\n  def forward(ctx, input, weight, bias, normalized_shape, eps):\n    global fused_layer_norm_cuda\n    if fused_layer_norm_cuda is None:\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n    ctx.normalized_shape = normalized_shape\n    ctx.eps = eps\n    input_ = input.contiguous()\n    weight_ = weight.contiguous()\n    bias_ = bias.contiguous()\n    output, mean, invvar = fused_layer_norm_cuda.forward_affine(\n        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)\n    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)\n    return output\n\n  @staticmethod\n  def backward(ctx, grad_output):\n    input_, weight_, bias_, mean, invvar = ctx.saved_tensors\n    grad_input = grad_weight = grad_bias = None\n    grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(\n        grad_output.contiguous(), mean, invvar,\n        input_, ctx.normalized_shape,\n        weight_, bias_, ctx.eps)\n    return grad_input, grad_weight, grad_bias, None, None\n\nclass FusedLayerNormFunction(torch.autograd.Function):\n\n  @staticmethod\n  def forward(ctx, input, normalized_shape, eps):\n    global fused_layer_norm_cuda\n    if fused_layer_norm_cuda is None:\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n    ctx.normalized_shape = normalized_shape\n    ctx.eps = eps\n    input_ = input.contiguous()\n    output, mean, invvar = fused_layer_norm_cuda.forward(\n        input_, ctx.normalized_shape, ctx.eps)\n    ctx.save_for_backward(input_, mean, invvar)\n    return output\n\n  @staticmethod\n  def backward(ctx, grad_output):\n    input_, mean, invvar = 
ctx.saved_tensors\n    grad_input = None\n    grad_input = fused_layer_norm_cuda.backward(\n        grad_output.contiguous(), mean, invvar,\n        input_, ctx.normalized_shape,\n        ctx.eps)\n    return grad_input, None, None\n\ndef fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):\n    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)\n\ndef fused_layer_norm(input, normalized_shape, eps=1e-6):\n    return FusedLayerNormFunction.apply(input, normalized_shape, eps)\n\nclass FusedLayerNorm(torch.nn.Module):\n    r\"\"\"Applies Layer Normalization over a mini-batch of inputs as described in\n    the paper `Layer Normalization`_ .\n\n    Currently only runs on cuda() tensors.\n\n    .. math::\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated separately over the last\n    certain number dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n\n    .. note::\n        Unlike Batch Normalization and Instance Normalization, which applies\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, Layer Normalization applies per-element scale and\n        bias with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or torch.Size): input shape from an expected input\n            of size\n\n            .. 
math::\n                [* \\times \\text{normalized}\\_\\text{shape}[0] \\times \\text{normalized}\\_\\text{shape}[1]\n                    \\times \\ldots \\times \\text{normalized}\\_\\text{shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    Examples::\n\n        >>> input = torch.randn(20, 5, 10, 10)\n        >>> # With Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])\n        >>> # Without Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)\n        >>> # Normalize over last two dimensions\n        >>> m = apex.normalization.FusedLayerNorm([10, 10])\n        >>> # Normalize over last dimension of size 10\n        >>> m = apex.normalization.FusedLayerNorm(10)\n        >>> # Activating the module\n        >>> output = m(input)\n\n    .. 
_`Layer Normalization`: https://arxiv.org/abs/1607.06450\n    \"\"\"\n    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):\n        super(FusedLayerNorm, self).__init__()\n\n        global fused_layer_norm_cuda\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        if self.elementwise_affine:\n            self.weight = Parameter(torch.Tensor(*normalized_shape))\n            self.bias = Parameter(torch.Tensor(*normalized_shape))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.elementwise_affine:\n            init.ones_(self.weight)\n            init.zeros_(self.bias)\n\n    def forward(self, input):\n        if not input.is_cuda:\n            return  F.layer_norm(\n                input, self.normalized_shape, self.weight, self.bias, self.eps)\n        if self.elementwise_affine:\n          return FusedLayerNormAffineFunction.apply(\n              input, self.weight, self.bias, self.normalized_shape,self.eps)\n        else:\n          return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)\n\n    def extra_repr(self):\n        return '{normalized_shape}, eps={eps}, ' \\\n            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)\n"
  },
  {
    "path": "KoSimCSE/apex/optimizers/__init__.py",
    "content": "from .fused_sgd import FusedSGD\nfrom .fused_adam import FusedAdam\nfrom .fused_novograd import FusedNovoGrad\nfrom .fused_lamb import FusedLAMB\nfrom .fused_adagrad import FusedAdagrad"
  },
  {
    "path": "KoSimCSE/apex/optimizers/fused_adagrad.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedAdagrad(torch.optim.Optimizer):\n    \"\"\"Implements Adagrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adagrad implements 2 fusions.\n      * Fusion of the Adagrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdagrad` may be used with or without Amp.  If you wish to use :class:`FusedAdagrad` with Amp,\n    you may choose any ``opt_level``::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Adaptive Subgradient Methods for Online Learning\n    and Stochastic Optimization`_.\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-2)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-10)\n        adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: False)\n\n    .. 
_Adaptive Subgradient Methods for Online Learning and Stochastic\n        Optimization: http://jmlr.org/papers/v12/duchi11a.html\n    \"\"\"\n    def __init__(self, params, lr=1e-2, eps=1e-10,\n                 weight_decay=0., set_grad_none=True, adagrad_w_mode=False):\n\n        defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay)\n        super(FusedAdagrad, self).__init__(params, defaults)\n        self.adagrad_w_mode = 1 if adagrad_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad\n        else:\n            raise RuntimeError('apex.optimizers.FusedAdagrad requires cuda extensions')\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedAdagrad, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            # create lists for multi-tensor apply\n            g_16, p_16, h_16 = [], [], []\n            g_32, p_32, h_32 = [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedAdagrad does not support sparse gradients')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential 
moving average of gradient values\n                    state['sum'] = torch.zeros_like(p.data)\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    h_16.append(state['sum'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    h_32.append(state['sum'])\n                else:\n                    raise RuntimeError('FusedAdagrad only support fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_adagrad,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, h_16],\n                                     group['lr'],\n                                     group['eps'],\n                                     self.adagrad_w_mode,\n                                     group['weight_decay'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_adagrad,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, h_32],\n                                     group['lr'],\n                                     group['eps'],\n                                     self.adagrad_w_mode,\n                                     group['weight_decay'])\n\n        return loss"
  },
  {
    "path": "KoSimCSE/apex/optimizers/fused_adam.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedAdam(torch.optim.Optimizer):\n\n    \"\"\"Implements Adam algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adam implements 2 fusions.\n\n      * Fusion of the Adam update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,\n    or ``torch.optim.Adam`` with ``adam_w_mode=False``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdam` may be used with or without Amp.  If you wish to use :class:`FusedAdam` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n\n    .. warning::\n        A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``.  These additional arguments\n        are now deprecated and unnecessary.\n\n    Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. 
(default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay(also known as AdamW) (default: True)\n        set_grad_none (bool, optional): whether set grad to None when zero_grad()\n            method is called. (default: True)\n\n    .. _Adam - A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,\n                 weight_decay=0., amsgrad=False, set_grad_none=True):\n\n        if amsgrad:\n            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay)\n        super(FusedAdam, self).__init__(params, defaults)\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_adam = amp_C.multi_tensor_adam\n        else:\n            raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions')\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                
for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedAdam, self).zero_grad()\n\n    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n\n        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.\n        \"\"\"\n        if any(p is not None for p in [grads, output_params, scale, grad_norms]):\n            raise RuntimeError('FusedAdam has been updated.  Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = 
torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n                    v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedAdam only support fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_adam,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     self.adam_w_mode,\n                                     bias_correction,\n                                     group['weight_decay'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_adam,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     self.adam_w_mode,\n                                     bias_correction,\n                        
             group['weight_decay'])\n\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/optimizers/fused_lamb.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedLAMB(torch.optim.Optimizer):\n\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): whether to apply L2 regularization or\n            decoupled weight decay; True for decoupled weight decay (also known\n            as AdamW) (default: True)\n        grad_averaging (bool, optional): whether to apply (1-beta2) to the grad when\n            calculating running averages of the gradient. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip the global grad norm\n            (default: 1.0)\n        use_nvlamb (boolean, optional): whether to apply the adaptive learning rate to\n            parameters with 0.0 weight decay (default: False)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,\n                 amsgrad=False, adam_w_mode=True,\n                 grad_averaging=True, set_grad_none=True,\n                 max_grad_norm=1.0, use_nvlamb=False):\n        if amsgrad:\n            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        max_grad_norm=max_grad_norm)\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_lamb = amp_C.multi_tensor_lamb\n        else:\n            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n        self.use_nvlamb = use_nvlamb\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create 
separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')\n\n        device = self.param_groups[0][\"params\"][0].device\n        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_32], False)[0]\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,\n                                             self._dummy_overflow_buf,\n                                             [g_all_16], False)[0]\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,\n                                                self._dummy_overflow_buf,\n                                                [[g_norm_32, g_norm_16]],\n                                                False)[0]\n        max_grad_norm = self.defaults['max_grad_norm']\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into 
kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                    v_16.append(state['exp_avg_sq'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state['exp_avg'])\n                    v_32.append(state['exp_avg_sq'])\n                else:\n                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16, v_16],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     
group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm,\n                                     self.use_nvlamb)\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_lamb,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32, v_32],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.adam_w_mode,\n                                     global_grad_norm,\n                                     max_grad_norm,\n                                     self.use_nvlamb)\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/optimizers/fused_novograd.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedNovoGrad(torch.optim.Optimizer):\n\n    \"\"\"Implements NovoGrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused NovoGrad implements 2 fusions.\n\n      * Fusion of the NovoGrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedNovoGrad`'s usage is identical to any Pytorch optimizer::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedNovoGrad` may be used with or without Amp.  If you wish to use :class:`FusedNovoGrad` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.\n    More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        reg_inside_moment (bool, optional): whether to apply regularization (norm and L2)\n            inside the momentum calculation; if False, it is applied only to the\n            update term. (default: False)\n        grad_averaging (bool, optional): whether to apply (1-beta1) to the grad when\n            calculating running averages of the gradient. (default: True)\n        norm_type (int, optional): which norm to calculate for each layer.\n            2 for L2 norm, and 0 for infinity norm. These are the only\n            supported types for now. (default: 2)\n        init_zero (bool, optional): whether to initialize the norm with 0 (start averaging on\n            the 1st step) or with the first step's norm (start averaging on the\n            2nd step). True to initialize with 0. (default: False)\n        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()\n            method is called. (default: True)\n\n    .. _Jasper - An End-to-End Convolutional Neural Acoustic Model:\n        https://arxiv.org/abs/1904.03288\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, bias_correction=True,\n                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.,\n                 amsgrad=False, reg_inside_moment=False,\n                 grad_averaging=True, norm_type=2, init_zero=False,\n                 set_grad_none=True):\n        if amsgrad:\n            raise RuntimeError('FusedNovoGrad does not support the AMSGrad variant.')\n        defaults = dict(lr=lr, bias_correction=bias_correction,\n                        betas=betas, eps=eps, weight_decay=weight_decay,\n                        grad_averaging=grad_averaging, norm_type=norm_type,\n                        init_zero=init_zero)\n        super(FusedNovoGrad, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n\n            # Creating the overflow buffer on the same device as the params tensors.\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_novograd = amp_C.multi_tensor_novograd\n        else:\n            raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')\n\n        self.moment_mode = 0 if reg_inside_moment else 1\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedNovoGrad, self).zero_grad()\n\n    def load_state_dict(self, state_dict):\n        super(FusedNovoGrad, self).load_state_dict(state_dict)\n        # in case exp_avg_sq is not on the same device as params, move it there\n        for group in self.param_groups:\n            if len(group['params']) > 0:\n                group['exp_avg_sq'][0] = 
group['exp_avg_sq'][0].to(group['params'][0].device)\n                group['exp_avg_sq'][1] = group['exp_avg_sq'][1].to(group['params'][0].device)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            bias_correction = 1 if group['bias_correction'] else 0\n            beta1, beta2 = group['betas']\n            grad_averaging = 1 if group['grad_averaging'] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16 = [], [], []\n            g_32, p_32, m_32 = [], [], []\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError('FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state['exp_avg'])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                 
   m_32.append(state['exp_avg'])\n                else:\n                    raise RuntimeError('FusedNovoGrad only support fp16 and fp32.')\n\n            # we store per weight norm as one tensor for one group/precision combination\n            # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types\n            if 'exp_avg_sq' not in group:\n                group['exp_avg_sq'] = [None, None]\n                if group['init_zero']:\n                    # Creating the following parameters on the same device as the params tensors.\n                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0][\"params\"][0].device).contiguous().fill_(0)\n                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0][\"params\"][0].device).contiguous().fill_(0)\n                else: # init with first step norm, so first blend have no effect\n                    if group['norm_type'] == 0:\n                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]\n                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]\n                    elif group['norm_type'] == 2:\n                        v_16 = [torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16]\n                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]\n                    else:\n                        raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')\n                    # Creating the following parameters on the same device as the params tensors.\n                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0][\"params\"][0].device)\n                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0][\"params\"][0].device)\n            else:\n                assert(len(g_16) == group['exp_avg_sq'][0].numel())\n   
             assert(len(g_32) == group['exp_avg_sq'][1].numel())\n\n            if(len(g_16) > 0):\n                multi_tensor_applier(self.multi_tensor_novograd,\n                                     self._dummy_overflow_buf,\n                                     [g_16, p_16, m_16],\n                                     group['exp_avg_sq'][0],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.moment_mode,\n                                     group['norm_type'])\n            if(len(g_32) > 0):\n                multi_tensor_applier(self.multi_tensor_novograd,\n                                     self._dummy_overflow_buf,\n                                     [g_32, p_32, m_32],\n                                     group['exp_avg_sq'][1],\n                                     group['lr'],\n                                     beta1,\n                                     beta2,\n                                     group['eps'],\n                                     group['step'],\n                                     bias_correction,\n                                     group['weight_decay'],\n                                     grad_averaging,\n                                     self.moment_mode,\n                                     group['norm_type'])\n\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/optimizers/fused_sgd.py",
    "content": "import torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused SGD implements 2 fusions.\n\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedSGD` may be used with or without Amp.  If you wish to use :class:`FusedSGD` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n        
>>> optimizer.zero_grad()\n        >>> loss_fn(model(input), target).backward()\n        >>> optimizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et. al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et. al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(self, params, lr=required, momentum=0, dampening=0,\n                 weight_decay=0, nesterov=False,\n                 wd_after_momentum=False,\n                 materialize_master_grads=True,\n                 set_grad_none=False):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,\n                        weight_decay=weight_decay, nesterov=nesterov)\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = wd_after_momentum\n        self.materialize_master_grads = materialize_master_grads\n        
self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device)\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('nesterov', False)\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group['params']:\n                    p.grad = None\n        else:\n            super(FusedSGD, self).zero_grad()\n\n    def get_momentums(self, params):\n        momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if 'momentum_buffer' not in param_state:\n                first_run = True\n                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state['momentum_buffer'])\n        return momentums, first_run\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = 
closure()\n\n        explicit_master_params = (hasattr(self, \"_amp_stash\") and\n                                  hasattr(self._amp_stash, \"fp32_from_fp16_groups\"))\n\n        for gid, group in enumerate(self.param_groups):\n            weight_decay = group['weight_decay']\n            momentum = group['momentum']\n            dampening = group['dampening']\n            nesterov = group['nesterov']\n\n\n            # For each group, there are 3 possible combinations we need to consider:\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # 1. fp16, fp16, fp16, No\n            # 2. fp32, fp32, fp32, No\n            # 3. fp16, fp32, fp32, Yes\n\n            first_runs = [True, True]\n\n            # I think a bit of code divergence in exchange for naming clarity is worthwhile\n            if explicit_master_params:\n                stash = self._amp_stash\n\n                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]\n                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                if self.materialize_master_grads:\n                    fp16_model_params = [p for i, p in enumerate(\n                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]\n                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n\n                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,\n                                fp32_from_fp16_momentums, fp16_model_params]\n                else:\n                    fp16_model_params = [p for p in 
stash.fp16_groups[gid] if p.grad is not None]\n                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]\n                    fp32_from_fp16_params = [p for i, p in enumerate(\n                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]\n                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n\n                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,\n                                fp32_from_fp16_momentums, fp16_model_params]\n\n                launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]\n            else:\n                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]\n                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]\n                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)\n\n                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]\n                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],\n                               [fp32_grads, fp32_params, fp32_momentums]]\n\n            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                
        dampening,\n                        group['lr'],\n                        nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0/self.most_recent_scale)\n\n        self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n\n        return loss\n"
  },
  {
    "path": "KoSimCSE/apex/parallel/LARC.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn.parameter import Parameter\n\nclass LARC(object):\n    \"\"\"\n    :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,\n    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive \n    local learning rate for each individual parameter. The algorithm is designed to improve\n    convergence of large batch training.\n     \n    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.\n\n    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate\n    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.\n\n    ```\n    model = ...\n    optim = torch.optim.Adam(model.parameters(), lr=...)\n    optim = LARC(optim)\n    ```\n\n    It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.\n\n    ```\n    model = ...\n    optim = torch.optim.Adam(model.parameters(), lr=...)\n    optim = LARC(optim)\n    optim = apex.fp16_utils.FP16_Optimizer(optim)\n    ```\n\n    Args:\n        optimizer: Pytorch optimizer to wrap and modify learning rate for.\n        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888\n        clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. 
If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.\n        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr\n    \"\"\"\n\n    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):\n        self.optim = optimizer\n        self.trust_coefficient = trust_coefficient\n        self.eps = eps\n        self.clip = clip\n\n    def __getstate__(self):\n        return self.optim.__getstate__()\n\n    def __setstate__(self, state):\n        self.optim.__setstate__(state)\n\n    @property\n    def state(self):\n        return self.optim.state\n\n    def __repr__(self):\n        return self.optim.__repr__()\n\n    @property\n    def param_groups(self):\n        return self.optim.param_groups\n\n    @param_groups.setter\n    def param_groups(self, value):\n        self.optim.param_groups = value\n    \n    def state_dict(self):\n        return self.optim.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self.optim.load_state_dict(state_dict)\n\n    def zero_grad(self):\n        self.optim.zero_grad()\n\n    def add_param_group(self, param_group):\n        self.optim.add_param_group( param_group)\n\n    def step(self):\n        with torch.no_grad():\n            weight_decays = []\n            for group in self.optim.param_groups:\n                # absorb weight decay control from optimizer\n                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0\n                weight_decays.append(weight_decay)\n                group['weight_decay'] = 0\n                for p in group['params']:\n                    if p.grad is None:\n                        continue\n                    param_norm = torch.norm(p.data)\n                    grad_norm = torch.norm(p.grad.data)\n\n                    if param_norm != 0 and grad_norm != 0:\n                        # calculate adaptive lr + weight decay\n                        adaptive_lr = self.trust_coefficient * 
(param_norm) / (grad_norm + param_norm * weight_decay + self.eps)\n\n                        # clip learning rate for LARC\n                        if self.clip:\n                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`\n                            adaptive_lr = min(adaptive_lr/group['lr'], 1)\n\n                        p.grad.data += weight_decay * p.data\n                        p.grad.data *= adaptive_lr\n\n        self.optim.step()\n        # return weight decay control to optimizer\n        for i, group in enumerate(self.optim.param_groups):\n            group['weight_decay'] = weight_decays[i]\n"
  },
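The trust-ratio arithmetic in `LARC.step()` above is compact and easy to misread. The following pure-Python sketch isolates that computation for a single parameter; `larc_adaptive_lr` is a hypothetical helper for illustration (operating on flat lists of floats rather than tensors), not part of apex.

```python
import math

def larc_adaptive_lr(param, grad, lr, trust_coefficient=0.02,
                     weight_decay=0.0, clip=True, eps=1e-8):
    # Hypothetical helper: the per-parameter scale factor LARC.step()
    # applies to the gradient, with param/grad as flat lists of floats.
    param_norm = math.sqrt(sum(x * x for x in param))
    grad_norm = math.sqrt(sum(g * g for g in grad))
    # LARC leaves the gradient unscaled when either norm is zero.
    if param_norm == 0.0 or grad_norm == 0.0:
        return 1.0
    # Local learning rate from https://arxiv.org/abs/1708.03888,
    # with weight decay folded into the denominator.
    adaptive_lr = trust_coefficient * param_norm / (
        grad_norm + param_norm * weight_decay + eps)
    if clip:
        # Rescale so that, once the optimizer multiplies by lr, the
        # effective learning rate becomes min(adaptive_lr, lr).
        adaptive_lr = min(adaptive_lr / lr, 1.0)
    return adaptive_lr

# A parameter of norm 5 with gradient norm 2 and base lr 0.1:
scale = larc_adaptive_lr([3.0, 4.0], [0.0, 2.0], lr=0.1)   # close to 0.5
```

In `clip=True` mode the returned factor times the optimizer's `lr` reproduces `min(adaptive_lr, lr)`; in `clip=False` mode `step()` multiplies the gradient by `adaptive_lr` directly, so the effective rate is `local_lr * optimizer_lr` as the docstring states.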
  {
    "path": "KoSimCSE/apex/parallel/README.md",
    "content": "## Distributed Data Parallel\n\ndistributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library.\n\n`apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with\ncomputation in the backward pass and bucketing smaller transfers to reduce the total number of\ntransfers required.\n\nmultiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs.\n\n#### [API Documentation](https://nvidia.github.io/apex/parallel.html)\n\n#### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)\n\n#### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\n#### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex)\n\n### Synchronized Batch Normalization\n\n`apex.parallel.SyncBatchNorm` has an API similar to `torch.nn.BatchNorm*N*d`.\nIt reduces stats on the first (channel) dimension of the Tensor and accepts\narbitrary spatial dimensions.\n\n#### Installation\n\nApex provides two sync BN implementations:\n\n1. A Python-only implementation, which is the default\nwhen installing with `python setup.py install`.\nIt uses PyTorch primitive operations and the distributed communication package from\n`torch.distributed`.\n\n   - _The Python-only implementation requires the input tensor to be of the same data type as the\nlayer._\n\n2. An implementation with kernels through a CUDA/C++ extension, offering\nimproved performance. We are experimenting with Welford and Kahan summation for the reduction,\nhoping to get better accuracy.\n   To use the kernel implementation, users need to install Apex with the CUDA extension\nenabled: `python setup.py install --cuda_ext`.\n\n   - _The custom kernel implementation supports fp16 input with an fp32 layer, as cuDNN does.\nThis is required to run the imagenet example in fp16._\n\n   - _Currently the kernel implementation only supports GPU._\n\n#### HowTo\n\n1. Users can use `apex.parallel.SyncBatchNorm` by building their module with\nthe layer explicitly.\n\n```\nimport torch\nimport apex\ninput_t = torch.randn(3, 5, 20).cuda()\nsbn = apex.parallel.SyncBatchNorm(5).cuda()\noutput_t = sbn(input_t)\n```\n\n2. Users can also take a constructed `torch.nn.Module` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`.\n\n```\n# model is an instance of torch.nn.Module\nimport apex\nsync_bn_model = apex.parallel.convert_syncbn_model(model)\n```\n"
  },
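The installation notes above mention experimenting with Welford-style reduction for better accuracy. As a rough sketch of why per-process batch-norm statistics can be merged exactly, here is the standard parallel update for combining two partial (count, mean, M2) triples; `combine_stats` and `stats` are hypothetical helpers for illustration only, not apex's actual reduction kernel.

```python
def combine_stats(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    # Merge two partial statistics (count, mean, M2), where M2 is the sum
    # of squared deviations from the mean -- the Chan/Welford-style update.
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

def stats(xs):
    # Per-process statistics for one channel's values.
    n = len(xs)
    mean = sum(xs) / n
    m2 = sum((x - mean) ** 2 for x in xs)
    return n, mean, m2

# Two "processes" each see part of the batch for one channel:
n, mean, m2 = combine_stats(*stats([1.0, 2.0, 3.0]), *stats([4.0, 5.0]))
# matches the global statistics of [1, 2, 3, 4, 5]: mean 3.0, biased var m2/n = 2.0
```

Merging (count, mean, M2) rather than raw sums of squares avoids the catastrophic cancellation of the naive E[x^2] - E[x]^2 formula, which is the motivation for Welford/Kahan-style reductions mentioned above.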
  {
    "path": "KoSimCSE/apex/parallel/__init__.py",
    "content": "import torch\n\nif hasattr(torch.distributed, 'ReduceOp'):\n    ReduceOp = torch.distributed.ReduceOp\nelif hasattr(torch.distributed, 'reduce_op'):\n    ReduceOp = torch.distributed.reduce_op\nelse:\n    ReduceOp = torch.distributed.deprecated.reduce_op\n\nfrom .distributed import DistributedDataParallel, Reducer\n# This is tricky because I'd like SyncBatchNorm to be exposed the same way\n# for both the cuda-enabled and python-fallback versions, and I don't want\n# to suppress the error information.\ntry:\n    import syncbn\n    from .optimized_sync_batchnorm import SyncBatchNorm\nexcept ImportError as err:\n    from .sync_batchnorm import SyncBatchNorm\n    SyncBatchNorm.syncbn_import_error = err\n\ndef convert_syncbn_model(module, process_group=None, channel_last=False):\n    '''\n    Recursively traverse module and its children to replace all instances of\n    ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`.\n\n    All ``torch.nn.BatchNorm*N*d`` wrap around\n    ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily switch\n    to use sync BN.\n\n    Args:\n        module (torch.nn.Module): input module\n\n    Example::\n\n        >>> # model is an instance of torch.nn.Module\n        >>> import apex\n        >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)\n    '''\n    mod = module\n    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):\n        return module\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)\n        mod.running_mean = module.running_mean\n        mod.running_var = module.running_var\n        mod.num_batches_tracked = module.num_batches_tracked\n        if module.affine:\n            mod.weight.data = module.weight.data.clone().detach()\n            mod.bias.data 
= module.bias.data.clone().detach()\n    for name, child in module.named_children():\n        mod.add_module(name, convert_syncbn_model(child,\n                                                  process_group=process_group,\n                                                  channel_last=channel_last))\n    # TODO(jie) should I delete model explicitly?\n    del module\n    return mod\n\ndef create_syncbn_process_group(group_size):\n    '''\n    Creates process groups to be used for sync BN of a given ``group_size`` and returns\n    the process group that the current GPU participates in.\n\n    ``group_size`` must divide the total number of GPUs (world_size).\n\n    A ``group_size`` of 0 is treated as equal to world_size; in this case ``None`` will be returned.\n\n    A ``group_size`` of 1 is equivalent to using non-sync BN, but will still carry the communication overhead.\n\n    Args:\n        group_size (int): number of GPUs that collaborate for sync BN\n\n    Example::\n\n        >>> # model is an instance of torch.nn.Module\n        >>> import apex\n        >>> group = apex.parallel.create_syncbn_process_group(group_size)\n    '''\n\n    if group_size == 0:\n        return None\n\n    world_size = torch.distributed.get_world_size()\n    assert(world_size >= group_size)\n    assert(world_size % group_size == 0)\n\n    group = None\n    for group_num in range(world_size // group_size):\n        group_ids = range(group_num * group_size, (group_num + 1) * group_size)\n        cur_group = torch.distributed.new_group(ranks=group_ids)\n        if torch.distributed.get_rank() // group_size == group_num:\n            group = cur_group\n            # Cannot return early here; every process must participate in creating all subgroups.\n\n    assert(group is not None)\n    return group\n"
  },
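`create_syncbn_process_group` above partitions ranks into contiguous blocks of `group_size` and returns the block containing the current rank. That membership computation can be sketched without `torch.distributed`; `syncbn_group_ranks` is a hypothetical helper mirroring the grouping loop, not part of apex.

```python
def syncbn_group_ranks(world_size, group_size, rank):
    # Mirror of the grouping logic in create_syncbn_process_group: ranks
    # are partitioned into contiguous blocks of group_size, and each rank
    # belongs to exactly one block.
    assert group_size > 0 and world_size % group_size == 0
    group_num = rank // group_size
    return list(range(group_num * group_size, (group_num + 1) * group_size))

# With 8 GPUs and group_size=4, ranks 0-3 and 4-7 form separate sync BN groups:
print(syncbn_group_ranks(8, 4, 5))  # [4, 5, 6, 7]
```

Note that the real function still has every process create all subgroups (`torch.distributed.new_group` is collective), which is why the loop in the source cannot return early.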
  {
    "path": "KoSimCSE/apex/parallel/distributed.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom torch.nn.modules import Module\nfrom torch.autograd import Variable\nfrom collections import OrderedDict\nfrom itertools import chain\nimport copy\nimport importlib\nfrom ..multi_tensor_apply import multi_tensor_applier\n\nimported_flatten_impl = False\n\ndef import_flatten_impl():\n    global flatten_impl, unflatten_impl, imported_flatten_impl\n    try:\n        import apex_C\n        flatten_impl = apex_C.flatten\n        unflatten_impl = apex_C.unflatten\n    except ImportError:\n        print(\"Warning:  apex was installed without --cpp_ext.  Falling back to Python flatten and unflatten.\")\n        flatten_impl = torch._utils._flatten_dense_tensors\n        unflatten_impl = torch._utils._unflatten_dense_tensors\n    imported_flatten_impl = True\n\ndef flatten(bucket):\n    if not imported_flatten_impl:\n        import_flatten_impl()\n    return flatten_impl(bucket)\n\ndef unflatten(coalesced, bucket):\n    if not imported_flatten_impl:\n        import_flatten_impl()\n    return unflatten_impl(coalesced, bucket)\n\n# apply_dist_call requires that tensors in 'bucket' are all the same type.\ndef apply_flat_dist_call(bucket, call, extra_args=None):\n\n    coalesced = flatten(bucket)\n\n    if extra_args is not None:\n        call(coalesced, *extra_args)\n    else:\n        call(coalesced)\n\n    if call is dist.all_reduce:\n        coalesced /= dist.get_world_size()\n\n    for buf, synced in zip(bucket, unflatten(coalesced, bucket)):\n        buf.copy_(synced)\n\ndef split_half_float_double(tensors):\n    dtypes = [\"torch.cuda.HalfTensor\",  \"torch.cuda.FloatTensor\", \"torch.cuda.DoubleTensor\"]\n    buckets = []\n    for i, dtype in enumerate(dtypes):\n        bucket = [t for t in tensors if t.type() == dtype]\n        if bucket:\n            buckets.append(bucket)\n    return buckets\n\ndef split_by_type(tensors):\n    buckets = OrderedDict()\n    for tensor in tensors:\n        tp = 
tensor.type()\n        if tp not in buckets:\n            buckets[tp] = []\n        buckets[tp].append(tensor)\n    return buckets\n\n# flat_dist_call organizes 'tensors' by type.\ndef flat_dist_call(tensors, call, extra_args=None):\n    buckets = split_by_type(tensors)\n\n    for tp in buckets:\n        bucket = buckets[tp]\n        apply_flat_dist_call(bucket, call, extra_args)\n\n\ndef extract_tensors(maybe_tensor, tensor_list):\n    if torch.is_tensor(maybe_tensor):\n        tensor_list.append(maybe_tensor)\n    else:\n        try:\n            for item in maybe_tensor:\n                extract_tensors(item, tensor_list)\n        except TypeError:\n            return\n\n\nclass Reducer(object):\n    \"\"\"\n    :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters\n    across processes.  :class:`Reducer` is intended to give the user additional control:\n    Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce\n    parameters during ``backward()``.\n    Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.\n    This enables, for example, delaying the allreduce to be carried out every\n    several iterations instead of every single iteration.\n\n    Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces\n    over the number of participating processes.\n\n    :class:`Reducer` is designed to work with the upstream launch utility script\n    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.\n    When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.\n    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.\n\n    Args:\n        module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced.  
If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values.  If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.\n    \"\"\"\n\n    def __init__(self, module_or_grads_list):\n        if isinstance(module_or_grads_list, Module):\n            self.module = module_or_grads_list\n            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )\n\n        else:\n            self.module = None\n            self.grads = []\n            extract_tensors(module_or_grads_list, self.grads)\n\n    def reduce(self):\n        if self.module:\n            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]\n            flat_dist_call(grads, dist.all_reduce)\n        else:\n            flat_dist_call(self.grads, dist.all_reduce)\n\n\nclass DistributedDataParallel(Module):\n    \"\"\"\n    :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables\n    easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``.  Parameters are broadcast across participating processes on initialization, and gradients are\n    allreduced and averaged over processes during ``backward()``.\n\n    :class:`DistributedDataParallel` is optimized for use with NCCL.  
It achieves high performance by\n    overlapping communication with computation during ``backward()`` and bucketing smaller gradient\n    transfers to reduce the total number of transfers required.\n\n    :class:`DistributedDataParallel` is designed to work with the upstream launch utility script\n    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.\n    When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.\n    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.\n\n    https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage.\n    https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example\n    that combines :class:`DistributedDataParallel` with mixed precision training.\n\n    Args:\n        module: Network definition to be run in multi-gpu/distributed mode.\n        message_size (int, default=1e7): Minimum number of elements in a communication bucket.\n        delay_allreduce (bool, default=False):  Delay all communication to the end of the backward pass.  This disables overlapping communication with computation.\n        allreduce_trigger_params (list, optional, default=None):  If supplied, should contain a list of parameters drawn from the model.  Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full).  At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically.  If allreduce_trigger_params is supplied, the message_size argument will be ignored.\n        allreduce_always_fp32 (bool, default=False):  Convert any FP16 gradients to FP32 before allreducing.  
This can improve stability for widely scaled-out runs.\n        gradient_average (bool, default=True):  Option to toggle whether or not DDP averages the allreduced gradients over processes.  For proper scaling, the default value of True is recommended.\n        gradient_predivide_factor (float, default=1.0):  Allows performing the average of gradients over processes partially before and partially after the allreduce.  Before allreduce:  ``grads.mul_(1.0/gradient_predivide_factor)``.  After allreduce:  ``grads.mul_(gradient_predivide_factor/world_size)``.  This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs.\n\n    .. warning::\n        If ``gradient_average=False``, the pre-allreduce division (``grads.mul_(1.0/gradient_predivide_factor)``) will still be applied, but the post-allreduce gradient averaging (``grads.mul_(gradient_predivide_factor/world_size)``) will be omitted.\n\n    \"\"\"\n\n    def __init__(self,\n                 module,\n                 message_size=10000000,\n                 delay_allreduce=False,\n                 shared_param=None,\n                 allreduce_trigger_params=None,\n                 retain_allreduce_buffers=False,\n                 allreduce_always_fp32=False,\n                 num_allreduce_streams=1,\n                 allreduce_communicators=None,\n                 gradient_average=True,\n                 gradient_predivide_factor=1.0,\n                 gradient_average_split_factor=None,\n                 prof=False):\n        super(DistributedDataParallel, self).__init__()\n\n        # Backward/forward compatibility around\n        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and\n        # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86\n        if hasattr(dist, \"get_backend\"):\n            self._backend = dist.get_backend()\n            if hasattr(dist, \"DistBackend\"):\n                self.backend_enum_holder = dist.DistBackend\n            else:\n                self.backend_enum_holder = dist.Backend\n        else:\n            self._backend = dist._backend\n            self.backend_enum_holder = dist.dist_backend\n\n        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False\n\n        self.prof = prof\n\n        self.allreduce_different_streams = (num_allreduce_streams > 1)\n        self.num_allreduce_streams = num_allreduce_streams\n        self.allreduce_communicators = allreduce_communicators\n        if self.allreduce_communicators:\n            assert len(allreduce_communicators[0]) == num_allreduce_streams\n            assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])\n            assert self.allreduce_different_streams\n\n        if self.allreduce_different_streams and delay_allreduce:\n            raise ValueError(\"self.allreduce_different_streams may only be used if delay_allreduce=False.\")\n\n        if shared_param is not None:\n            raise ValueError(\"shared_param is no longer supported as an option.  It was misleadingly named from the start.  It turns out overlapping communication with computation should work fine with shared parameters.  
If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.\")\n\n        self.world_size = float(dist.get_world_size())\n\n        self.retain_allreduce_buffers = retain_allreduce_buffers\n        self.allreduce_always_fp32 = allreduce_always_fp32\n        self.gradient_average = gradient_average\n        self.gradient_predivide_factor = gradient_predivide_factor\n\n        self.custom_allreduce_triggers = False\n        if allreduce_trigger_params is not None:\n            if delay_allreduce:\n                raise ValueError(\"Setting allreduce_trigger_params is only valid if delay_allreduce=False.\")\n            self.custom_allreduce_triggers = True\n            self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])\n\n        self.delay_allreduce = delay_allreduce\n        self.message_size = message_size\n\n        self.main_stream = torch.cuda.current_stream()\n\n        self.bucket_streams = []\n        self.bucket_events = []\n\n        self.module = module\n\n        self._disable_allreduce = False\n\n        if self._backend == self.backend_enum_holder.NCCL:\n            for param in self.module.parameters():\n                assert param.is_cuda, \"NCCL backend only supports model parameters to be on GPU.\"\n\n        self.active_params = []\n\n        self.param_type_to_tmp_i = {\"torch.cuda.HalfTensor\" : 0,\n                                    \"torch.cuda.FloatTensor\" : 1,\n                                    \"torch.cuda.DoubleTensor\" : 2}\n\n        if multi_tensor_applier.available:\n            # TODO:  I really need to centralize the C++ backed imports\n            import amp_C\n            self.multi_tensor_scale = amp_C.multi_tensor_scale\n            self._overflow_buf = torch.cuda.IntTensor([0])\n\n        self.create_hooks()\n\n        flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )\n\n\n    def 
__setstate__(self, state):\n        super(DistributedDataParallel, self).__setstate__(state)\n        if self.allreduce_different_streams and self.delay_allreduce:\n            raise ValueError(\"self.allreduce_different_streams may only be used if delay_allreduce=False.\")\n\n        if self.delay_allreduce:\n            self.needs_refresh = True\n\n        self.bucket_streams = []\n        self.bucket_events = []\n\n\n    def __getstate__(self):\n        attrs = copy.copy(self.__dict__)\n        if self._backend != self.backend_enum_holder.NCCL:\n            del attrs['bucket_streams']\n            del attrs['bucket_events']\n        return attrs\n\n    def enable_allreduce(self):\n        self._disable_allreduce = False\n\n    def disable_allreduce(self):\n        self._disable_allreduce = True\n\n    # Broadcast rank 0's bucket structure across all processes, and have all processes\n    # regenerate their bucket structures to match.\n    def sync_bucket_structure(self):\n        # Append leftover buckets\n        for tmp_bucket in self.tmp_buckets:\n            if len(tmp_bucket) > 0:\n                self.active_i_buckets.append(tmp_bucket)\n\n        self.num_buckets = len(self.active_i_buckets)\n        self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]\n\n        info_tensor = torch.cuda.IntTensor([self.num_buckets] +\n                                           self.bucket_sizes +\n                                           list(chain(*self.active_i_buckets)))\n\n        dist.broadcast(info_tensor, 0)\n\n        info = [int(entry) for entry in info_tensor]\n\n        self.num_buckets = info[0]\n        self.bucket_sizes = info[1:self.num_buckets + 1]\n        self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                        for i in range(self.num_buckets)]\n        # Technically, active_i_buckets' work is done.  But the information is still useful to\n        # keep around.  
Therefore, refresh active_i_buckets based on rank 0 as well.\n        self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]\n                                 for i in range(self.num_buckets)]\n\n        flattened_buckets = info[self.num_buckets + 1:]\n        flat_i = 0\n        for bucket_idx in range(self.num_buckets):\n            for bucket_loc in range(self.bucket_sizes[bucket_idx]):\n                param_i = flattened_buckets[flat_i]\n                self.active_i_buckets[bucket_idx][bucket_loc] = param_i\n                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)\n                flat_i += 1\n\n\n    def create_hooks(self):\n        # Fallback hook that's only called at the end of backward.\n        # Used if you deliberately want to delay allreduces to the end, or to refresh the\n        # bucket structure that will be used to overlap communication with computation in later\n        # iterations.\n        def allreduce_params():\n            # Bucket record refresh\n            if not self.delay_allreduce:\n                if self.needs_refresh:\n                    self.sync_bucket_structure()\n\n                    self.needs_refresh = False\n\n            self.allreduce_fallback()\n\n\n        def overlapping_backward_epilogue():\n            for stream, event in zip(self.bucket_streams, self.bucket_events):\n                stream.record_event(event)\n                torch.cuda.current_stream().wait_event(event)\n\n            # Sanity checks that all the buckets were kicked off\n            if self.next_bucket != self.num_buckets:\n                raise RuntimeError(\"In epilogue, next_bucket ({}) != num_buckets ({}).  
\".format(\n                                   self.next_bucket, self.num_buckets),\n                                   \"This probably indicates some buckets were not allreduced.\")\n\n            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):\n                if actual != expected:\n                    raise RuntimeError(\"Some param buckets were not allreduced.\")\n\n\n        self.grad_accs = []\n        for param in self.module.parameters():\n            if param.requires_grad:\n                def wrapper(param):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n\n                    def allreduce_hook(*unused):\n                        if self.prof:\n                            torch.cuda.nvtx.range_push(\"allreduce_hook\")\n\n                        if not self._disable_allreduce:\n                            if self.delay_allreduce or self.needs_refresh:\n                                # TODO:  How do we want to handle multiple backward passes between\n                                # each forward, e.g., backward passes with retain_graph=True?\n                                # needs_refresh and callback_queued are both vulnerable states.\n                                if not self.delay_allreduce and self.needs_refresh:\n                                    # Use the backward pass to build the bucket structure on the fly.\n                                    active_i = self.param_id_to_active_i[id(param)]\n\n                                    # Float, half, and double tensors are grouped into buckets separately.\n                                    current_type = self.param_type_to_tmp_i[param.type()]\n\n                                    self.tmp_buckets[current_type].append(active_i)\n\n                                    ship_tmp_bucket = False\n                                    if self.custom_allreduce_triggers:\n                               
         if id(param) in self.allreduce_trigger_params:\n                                            ship_tmp_bucket = True\n                                    else:\n                                        self.tmp_numels[current_type] += param.numel()\n                                        if self.tmp_numels[current_type] >= self.message_size:\n                                            ship_tmp_bucket = True\n\n                                    # To consider:  If custom_allreduce_triggers are in use, ship all\n                                    # tmp_buckets, not just tmp_buckets[current_type].\n                                    if ship_tmp_bucket:\n                                        self.active_i_buckets.append(self.tmp_buckets[current_type])\n                                        self.tmp_buckets[current_type] = []\n                                        self.tmp_numels[current_type] = 0\n\n                                if not self.callback_queued:\n                                    Variable._execution_engine.queue_callback(allreduce_params)\n                                    self.callback_queued = True\n                            else:\n                                if not self.callback_queued:\n                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)\n                                    self.callback_queued = True\n\n                                self.comm_ready_buckets(param)\n\n                        if self.prof:\n                            torch.cuda.nvtx.range_pop()\n\n                    grad_acc.register_hook(allreduce_hook)\n                    self.grad_accs.append(grad_acc)\n\n                wrapper(param)\n\n\n    def _stream_this_bucket(self, bucket_idx):\n        if self.allreduce_different_streams:\n            return self.bucket_streams[bucket_idx%self.num_allreduce_streams]\n        else:\n            return self.bucket_streams[0]\n\n\n    def 
_event_this_bucket(self, bucket_idx):\n        if self.allreduce_different_streams:\n            return self.bucket_events[bucket_idx%self.num_allreduce_streams]\n        else:\n            return self.bucket_events[0]\n\n\n    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):\n        tensor = flatten(bucket)\n\n        if force_default_stream:\n            bucket_stream = self.main_stream\n        else:\n            bucket_stream = self._stream_this_bucket(bucket_idx)\n            bucket_event = self._event_this_bucket(bucket_idx)\n            torch.cuda.current_stream().record_event(bucket_event)\n            bucket_stream.wait_event(bucket_event)\n\n        with torch.cuda.stream(bucket_stream):\n            # self.main_stream.wait_stream(torch.cuda.current_stream())\n            # torch.cuda.synchronize()\n\n            tensor_to_allreduce = tensor\n\n            if self.allreduce_always_fp32:\n                tensor_to_allreduce = tensor.float()\n\n            if self.gradient_predivide_factor != 1.0:\n                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)\n\n            if self.allreduce_different_streams and not force_default_stream:\n                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])\n            else:\n                dist.all_reduce(tensor_to_allreduce)\n\n            if self.gradient_average:\n                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)\n\n            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:\n                tensor.copy_(tensor_to_allreduce)\n\n            if not self.retain_allreduce_buffers:\n                if multi_tensor_applier.available:\n                    multi_tensor_applier(\n                        self.multi_tensor_scale,\n                        self._overflow_buf,\n                        [unflatten(tensor, bucket), bucket],\n                        1.0)\n        
        else:\n                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):\n                        buf.copy_(synced)\n\n            # I think we actually do need this here.  After allreduce_bucket returns, tensor will\n            # eventually go out of scope and die, at which point it could otherwise be freed for\n            # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.\n            tensor.record_stream(bucket_stream)\n\n        return tensor\n\n\n    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):\n        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)\n        if self.retain_allreduce_buffers:\n            if self.allreduce_buffers[bucket_idx] is not None:\n                raise RuntimeError(\"The backward pass is attempting to replace an already-filled \"\n                                   \"allreduce buffer.  This is almost certainly an error.\")\n            self.allreduce_buffers[bucket_idx] = allreduced\n            for view, grad in zip(unflatten(allreduced, bucket), bucket):\n                grad.data = view\n            # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):\n            #     buf.copy_(synced)\n\n\n    def allreduce_fallback(self):\n        for stream, event in zip(self.bucket_streams, self.bucket_events):\n            stream.record_event(event)\n            torch.cuda.current_stream().wait_event(event)\n\n        if self.retain_allreduce_buffers:\n            grads = [param.grad for param in self.module.parameters() if param.grad is not None]\n        else:\n            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]\n\n        split_buckets = split_half_float_double(grads)\n\n        # If retain_allreduce_buffers is True and delay_allreduce is False,\n        # this will only be done during the first backward pass, ignored by the\n        # 
training script, and overwritten in the next forward pass.  So it's harmless.\n        if self.retain_allreduce_buffers:\n            self.allreduce_buffers = [None for _ in range(len(split_buckets))]\n\n        for i, bucket in enumerate(split_buckets):\n            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)\n\n\n    def comm_ready_buckets(self, param):\n        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.\n        # self.reduction_stream.wait_stream(torch.cuda.current_stream())\n        if self.prof:\n            torch.cuda.nvtx.range_push(\"comm_ready_buckets\")\n\n        bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]\n\n        if self.buckets[bucket_idx][bucket_loc] is not None:\n            raise RuntimeError(\"The backward pass is attempting to replace an already-filled \"\n                               \"bucket slot.  This is almost certainly an error.\")\n\n        if self.retain_allreduce_buffers:\n            self.buckets[bucket_idx][bucket_loc] = param.grad\n        else:\n            self.buckets[bucket_idx][bucket_loc] = param.grad.data\n\n        self.buckets_ready_size[bucket_idx] += 1\n\n        if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:\n            if bucket_idx == self.next_bucket:\n                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)\n\n                self.next_bucket += 1\n\n                # Reversing upstream's logic here, because we constructed our buckets based on\n                # the order things were received during backward.\n                if len(self.ready_buckets_not_reduced) > 0:\n                    sorted_todo = sorted(self.ready_buckets_not_reduced)\n                    for i in sorted_todo:\n                        # Nothing can be reduced now\n                        if i > self.next_bucket:\n                            break\n                        elif i == 
self.next_bucket:\n                            self.allreduce_maybe_retain(self.buckets[i], i)\n                            self.ready_buckets_not_reduced.remove(i)\n                            self.next_bucket += 1\n                        else:\n                            raise ValueError(\"i should always be >= next_bucket\")\n            else:\n                self.ready_buckets_not_reduced.add(bucket_idx)\n\n        if self.prof:\n            torch.cuda.nvtx.range_pop()\n\n\n    def forward(self, *inputs, **kwargs):\n        result = self.module(*inputs, **kwargs)\n\n        if self.prof:\n            torch.cuda.nvtx.range_push(\"forward pass DDP logic\")\n\n        if not self._disable_allreduce:\n            if not self.delay_allreduce:\n                param_list = [param for param in self.module.parameters() if param.requires_grad]\n\n                # Conditions under which to refresh self.record\n                # Forward has the authority to set needs_refresh to True, but only allreduce_params\n                # in backward has the authority to set needs_refresh to False.\n                # Parentheses are not necessary for correct order of operations, but make the intent clearer.\n                if ((not self.active_params) or\n                    (len(param_list) != len(self.active_params)) or\n                    any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):\n                    self.needs_refresh = True\n\n                if self.needs_refresh:\n                    self.active_i_buckets = []\n                    self.buckets = []\n                    self.tmp_buckets = [[], [], []] # [running half, float, double buckets]\n                    self.tmp_numels = [0, 0, 0]\n                    self.bucket_sizes = []\n                    self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}\n                    self.param_id_to_bucket = {}\n                    self.bucket_pgs = []\n  
                  self.bucket_streams = []\n                    self.bucket_events = []\n                else:\n                    # self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                    #                 for i in range(self.num_buckets)]\n                    if not self.buckets:\n                        self.buckets = [[None for _ in range(self.bucket_sizes[i])]\n                                        for i in range(self.num_buckets)]\n                    else:\n                        assert len(self.buckets) == self.num_buckets, \"len(buckets) = {}, expected {}\".format(\n                            len(self.buckets), self.num_buckets)\n                        for b, bucket in enumerate(self.buckets):\n                            assert len(bucket) == self.bucket_sizes[b], \"len(buckets[{}]) = {}, expected {}\".format(\n                                b, len(bucket), self.bucket_sizes[b])\n                            for i in range(len(bucket)):\n                                bucket[i] = None\n\n                    if self.allreduce_communicators:\n                        self.bucket_pgs = self.allreduce_communicators[0]\n                        self.bucket_streams = self.allreduce_communicators[1]\n                        self.bucket_events = [torch.cuda.Event(enable_timing=False,\n                                            blocking=False) for _ in range(self.num_allreduce_streams)]\n                    else:\n                        if self.allreduce_different_streams:\n                            if not self.bucket_pgs:\n                                self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]\n                                for i, bg in enumerate(self.bucket_pgs):\n                                    print(\"rank {} created group {} with backend {}\".format(\n                                          dist.get_rank(), i, dist.get_backend(bg)))\n                        if 
self.allreduce_different_streams:\n                            if not self.bucket_streams:\n                                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]\n                                self.bucket_events = [torch.cuda.Event(enable_timing=False,\n                                                      blocking=False) for _ in range(self.num_allreduce_streams)]\n                        else:\n                            if not self.bucket_streams:\n                                self.bucket_streams = [torch.cuda.Stream()]\n                                self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]\n\n                    self.buckets_ready_size = [0 for i in range(self.num_buckets)]\n                    if(self.retain_allreduce_buffers):\n                        self.allreduce_buffers = [None for _ in range(self.num_buckets)]\n                    self.next_bucket = 0\n                    self.ready_buckets_not_reduced = set()\n\n                self.active_params = param_list\n\n            self.callback_queued = False\n\n        if self.prof:\n            torch.cuda.nvtx.range_pop()\n\n        return result\n"
  },
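The in-order scheduling inside `comm_ready_buckets` above is easy to miss among the CUDA stream bookkeeping: buckets fill in whatever order gradients arrive, but they are allreduced strictly by ascending index, and any bucket that fills early is parked in `ready_buckets_not_reduced` until its turn. A minimal, framework-free sketch of just that control flow (the class name `BucketScheduler` and the `reduce_fn` callback are illustrative, standing in for `allreduce_maybe_retain`; they are not apex APIs):

```python
class BucketScheduler:
    """Reduce buckets strictly in index order, parking early arrivals."""

    def __init__(self, num_buckets, reduce_fn):
        self.num_buckets = num_buckets
        self.reduce_fn = reduce_fn          # stands in for allreduce_maybe_retain
        self.next_bucket = 0
        self.ready_buckets_not_reduced = set()

    def bucket_ready(self, bucket_idx):
        if bucket_idx == self.next_bucket:
            self.reduce_fn(bucket_idx)
            self.next_bucket += 1
            # Drain any parked buckets that are now next in line.
            for i in sorted(self.ready_buckets_not_reduced):
                if i > self.next_bucket:
                    break
                self.reduce_fn(i)
                self.ready_buckets_not_reduced.remove(i)
                self.next_bucket += 1
        else:
            # Arrived out of order; hold until next_bucket catches up.
            self.ready_buckets_not_reduced.add(bucket_idx)
```

Feeding readiness events in the order 2, 0, 1, 3 reduces buckets in the order 0, 1, 2, 3, which is exactly the invariant the `i should always be >= next_bucket` check in the real code protects.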
  {
    "path": "KoSimCSE/apex/parallel/multiproc.py",
    "content": "import torch\nimport sys\nimport subprocess\n\ndef docstring_hack():\n    \"\"\"\n    Multiproc file which will launch a set of processes locally for multi-gpu\n    usage: python -m apex.parallel.multiproc main.py ...\n    \"\"\"\n    pass\n\nargslist = list(sys.argv)[1:]\nworld_size = torch.cuda.device_count()\n\nif '--world-size' in argslist:\n    world_size = int(argslist[argslist.index('--world-size')+1])\nelse:\n    argslist.append('--world-size')\n    argslist.append(str(world_size))\n\nworkers = []\n\nfor i in range(world_size):\n    if '--rank' in argslist:\n        argslist[argslist.index('--rank')+1] = str(i)\n    else:\n        argslist.append('--rank')\n        argslist.append(str(i))\n    stdout = None if i == 0 else open(\"GPU_\"+str(i)+\".log\", \"w\")\n    print(argslist)\n    p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)\n    workers.append(p)\n\nfor p in workers:\n    p.wait()\n"
  },
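The argv rewriting in `multiproc.py` above can be isolated into a small pure function: `--world-size` is set (or appended) once, and `--rank` is set per spawned process. A sketch of that logic (`build_rank_argv` is a hypothetical helper for illustration, not part of apex):

```python
def build_rank_argv(argslist, rank, world_size):
    """Return a per-rank copy of argslist with --world-size and --rank set."""
    args = list(argslist)
    if '--world-size' in args:
        args[args.index('--world-size') + 1] = str(world_size)
    else:
        args += ['--world-size', str(world_size)]
    if '--rank' in args:
        args[args.index('--rank') + 1] = str(rank)
    else:
        args += ['--rank', str(rank)]
    return args
```

The launcher then runs `[sys.executable] + build_rank_argv(...)` once per GPU, redirecting stdout of every rank except 0 to a `GPU_<i>.log` file.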
  {
    "path": "KoSimCSE/apex/parallel/optimized_sync_batchnorm.py",
    "content": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\n\nimport syncbn\nfrom .optimized_sync_batchnorm_kernel import SyncBatchnormFunction\n\n\nclass SyncBatchNorm(_BatchNorm):\n    \"\"\"\n    synchronized batch normalization module extended from `torch.nn.BatchNormNd`\n    with the added stats reduction across multiple processes.\n    :class:`apex.parallel.SyncBatchNorm` is designed to work with\n    `DistributedDataParallel`.\n\n    When running in training mode, the layer reduces stats across all processes\n    to increase the effective batch size for the normalization layer. This is\n    useful in applications where the per-process batch size is small enough to\n    diminish the converged accuracy of the model. The module uses the collective\n    communication package from `torch.distributed`.\n\n    When running in evaluation mode, the layer falls back to\n    `torch.nn.functional.batch_norm`.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. Default: ``True``\n        process_group: pass in a process group within which the stats of the\n            mini-batch are synchronized. 
``None`` for using default process\n            group\n        channel_last: a boolean value that when set to ``True``, this module\n            takes the last dimension of the input tensor to be the channel\n            dimension. Default: ``False``\n\n    Examples::\n        >>> # channel first tensor\n        >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()\n        >>> inp = torch.randn(10, 100, 14, 14).cuda()\n        >>> out = sbn(inp)\n        >>> inp = torch.randn(3, 100, 20).cuda()\n        >>> out = sbn(inp)\n        >>> # channel last tensor\n        >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()\n        >>> inp = torch.randn(10, 14, 14, 100).cuda()\n    \"\"\"\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):\n        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)\n        self.process_group = process_group\n        self.channel_last = channel_last\n        self.fuse_relu = fuse_relu\n\n    def _specify_process_group(self, process_group):\n        self.process_group = process_group\n\n    def _specify_channel_last(self, channel_last):\n        self.channel_last = channel_last\n\n    def forward(self, input, z=None):\n        # if input.dim() == 2, we switch to channel_last for efficient memory accessing\n        channel_last = self.channel_last if input.dim() != 2 else True\n\n        if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z is None:\n            # fall back to pytorch implementation for inference\n            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)\n        else:\n            exponential_average_factor = 0.0\n            if self.training and self.track_running_stats:\n                self.num_batches_tracked += 
1\n                if self.momentum is None:\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n                else:\n                    exponential_average_factor = self.momentum\n            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)\n"
  },
  {
    "path": "KoSimCSE/apex/parallel/optimized_sync_batchnorm_kernel.py",
    "content": "import torch\nfrom torch.autograd.function import Function\n\nimport syncbn\nfrom apex.parallel import ReduceOp\n\nclass SyncBatchnormFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):\n        input = input.contiguous()\n        world_size = 0\n\n        mean = None\n        var_biased = None\n        inv_std = None\n        var = None\n        out = None\n        count = None\n        if track_running_stats:\n            if channel_last:\n                count = int(input.numel()/input.size(-1))\n                mean, var_biased = syncbn.welford_mean_var_c_last(input)\n                num_channels = input.size(-1)\n            else:\n                count = int(input.numel()/input.size(1))\n                mean, var_biased = syncbn.welford_mean_var(input)\n                num_channels = input.size(1)\n\n            if torch.distributed.is_initialized():\n                if not process_group:\n                    process_group = torch.distributed.group.WORLD\n                device = mean.device\n                world_size = torch.distributed.get_world_size(process_group)\n\n                count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)\n                combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)\n                combined_list = [torch.empty_like(combined) for k in range(world_size)]\n                torch.distributed.all_gather(combined_list, combined, process_group)\n                combined = torch.stack(combined_list, dim=0)\n                mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)\n                count_all = count_all.view(-1)\n                mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps)\n            else:\n           
     device = mean.device\n                count_all = torch.cuda.IntTensor([count], device=device)\n                inv_std = 1.0 / torch.sqrt(var_biased + eps)\n                var = var_biased * (count) / (count-1)\n\n            if count == 1 and world_size < 2:\n                raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))\n\n            r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()\n            r_v_inc = var if running_variance.dtype != torch.float16 else var.half()\n            running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc\n            running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc\n        else:\n            mean = running_mean.data\n            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)\n\n        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32))\n        ctx.process_group = process_group\n        ctx.channel_last = channel_last\n        ctx.world_size = world_size\n        ctx.fuse_relu = fuse_relu\n\n        if channel_last:\n            out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)\n        else:\n            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)\n\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_output = grad_output.contiguous()\n        # mini batch mean & var are calculated by forward path.\n        # mu = 1./N*np.sum(h, axis = 0)\n        # var = 1./N*np.sum((h-mu)**2, axis = 0)\n        saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors\n        process_group = ctx.process_group\n        channel_last = ctx.channel_last\n        world_size = ctx.world_size\n        fuse_relu = ctx.fuse_relu\n        grad_input = grad_z = grad_weight = grad_bias = None\n\n        if fuse_relu:\n            grad_output = 
syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)\n        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:\n            grad_z = grad_output.clone()\n\n        # TODO: update kernel to not pre_divide by item_num\n        if channel_last:\n            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)\n        else:\n            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)\n\n        # calculate grad_input\n        if ctx.needs_input_grad[0]:\n\n            if torch.distributed.is_initialized():\n                num_channels = sum_dy.shape[0]\n                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)\n                torch.distributed.all_reduce(\n                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)\n                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)\n\n            if channel_last:\n                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)\n            else:\n                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)\n\n        if weight is None or not ctx.needs_input_grad[2]:\n            grad_weight = None\n\n        if weight is None or not ctx.needs_input_grad[3]:\n            grad_bias = None\n\n        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None\n"
  },
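The distributed branch of `SyncBatchnormFunction.forward` above gathers each process's per-channel `(mean, biased variance, count)` and hands them to `syncbn.welford_parallel` to produce global statistics. The same aggregation can be written in plain Python for a single channel via the usual moment identities; this is an illustration of the math only, not the CUDA kernel's actual algorithm (which uses a numerically stabler Welford-style merge), and `combine_stats` is a hypothetical name:

```python
def combine_stats(stats, eps=1e-5):
    """Merge per-process (mean, biased_var, count) triples into global stats.

    Global mean is the count-weighted mean; global E[x^2] is recovered from
    each process's biased variance via var + mean**2, then re-centered.
    """
    total = sum(count for _, _, count in stats)
    mean = sum(count * m for m, _, count in stats) / total
    sqr_mean = sum(count * (v + m * m) for m, v, count in stats) / total
    var_biased = sqr_mean - mean * mean
    inv_std = (var_biased + eps) ** -0.5   # what the kernel feeds to normalization
    return mean, var_biased, inv_std
```

For example, processes holding `[1, 2]` and `[3, 4]` contribute `(1.5, 0.25, 2)` and `(3.5, 0.25, 2)`, which merge to the same mean 2.5 and biased variance 1.25 one would compute over `[1, 2, 3, 4]` directly.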
  {
    "path": "KoSimCSE/apex/parallel/sync_batchnorm.py",
    "content": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\n\nfrom .sync_batchnorm_kernel import SyncBatchnormFunction\nfrom apex.parallel import ReduceOp\n\n\nclass SyncBatchNorm(_BatchNorm):\n    \"\"\"\n    synchronized batch normalization module extended from ``torch.nn.BatchNormNd``\n    with the added stats reduction across multiple processes.\n    :class:`apex.parallel.SyncBatchNorm` is designed to work with\n    ``DistributedDataParallel``.\n\n    When running in training mode, the layer reduces stats across all processes\n    to increase the effective batch size for the normalization layer. This is\n    useful in applications where the per-process batch size is small enough to\n    diminish the converged accuracy of the model. The module uses the collective\n    communication package from ``torch.distributed``.\n\n    When running in evaluation mode, the layer falls back to\n    ``torch.nn.functional.batch_norm``.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. 
Default: ``True``\n\n    Example::\n\n        >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()\n        >>> inp = torch.randn(10, 100, 14, 14).cuda()\n        >>> out = sbn(inp)\n        >>> inp = torch.randn(3, 100, 20).cuda()\n        >>> out = sbn(inp)\n    \"\"\"\n\n    warned = False\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):\n        if channel_last:\n            raise AttributeError(\"channel_last is not supported by primitive SyncBatchNorm implementation. Try installing apex with `--cuda_ext` if channel_last is desired.\")\n\n        if not SyncBatchNorm.warned:\n            if hasattr(self, \"syncbn_import_error\"):\n                print(\"Warning:  using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext.  The exception raised when attempting to import the cuda backend was: \", self.syncbn_import_error)\n            else:\n                print(\"Warning:  using Python fallback for SyncBatchNorm\")\n            SyncBatchNorm.warned = True\n\n        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)\n        self.process_group = process_group\n\n    def _specify_process_group(self, process_group):\n        self.process_group = process_group\n\n    def forward(self, input):\n        torch.cuda.nvtx.range_push(\"sync_bn_fw_with_mean_var\")\n        mean = None\n        var = None\n        cast = None\n        out = None\n\n        # casting to handle mismatch input type to layer type\n        if self.running_mean is not None:\n            if self.running_mean.dtype != input.dtype:\n                input = input.to(self.running_mean.dtype)\n                cast = input.dtype\n        elif self.weight is not None:\n            if self.weight.dtype != input.dtype:\n                input = input.to(self.weight.dtype)\n                
cast = input.dtype\n\n        if not self.training and self.track_running_stats:\n            # fall back to pytorch implementation for inference\n            torch.cuda.nvtx.range_pop()\n            out = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)\n        else:\n            process_group = self.process_group\n            world_size = 1\n            if not self.process_group:\n                process_group = torch.distributed.group.WORLD\n            self.num_batches_tracked += 1\n            with torch.no_grad():\n                channel_first_input = input.transpose(0, 1).contiguous()\n                squashed_input_tensor_view = channel_first_input.view(\n                    channel_first_input.size(0), -1)\n                # total number of data points for each variance entry. Used to calculate unbiased variance estimate\n                m = None\n                local_m = float(squashed_input_tensor_view.size()[1])\n                local_mean = torch.mean(squashed_input_tensor_view, 1)\n                local_sqr_mean = torch.pow(\n                    squashed_input_tensor_view, 2).mean(1)\n                if torch.distributed.is_initialized():\n                    world_size = torch.distributed.get_world_size(process_group)\n                    torch.distributed.all_reduce(\n                        local_mean, ReduceOp.SUM, process_group)\n                    mean = local_mean / world_size\n                    torch.distributed.all_reduce(\n                        local_sqr_mean, ReduceOp.SUM, process_group)\n                    sqr_mean = local_sqr_mean / world_size\n                    m = local_m * world_size\n                else:\n                    m = local_m\n                    mean = local_mean\n                    sqr_mean = local_sqr_mean\n                # var(x) = E (( x - mean_x ) ** 2)\n                #        = 1 / N * sum ( x - mean_x ) ** 2\n                #        = 1 / N * sum 
(x**2) - mean_x**2\n                var = sqr_mean - mean.pow(2)\n\n                if self.running_mean is not None:\n                    self.running_mean = self.momentum * mean + \\\n                        (1 - self.momentum) * self.running_mean\n                if self.running_var is not None:\n                    # as noted by the paper, we use the unbiased variance estimate of the mini-batch\n                    # Var[x] = m / (m-1) * Eb (sample_variance)\n                    self.running_var = m / \\\n                        (m-1) * self.momentum * var + \\\n                        (1 - self.momentum) * self.running_var\n            torch.cuda.nvtx.range_pop()\n            out = SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)\n        return out.to(cast)\n"
  },
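The running-stat update in the Python fallback above has one subtlety worth isolating: the running mean mixes in the batch mean directly, while the running variance mixes in the *unbiased* batch variance, rescaling the biased estimate by `m / (m - 1)` where `m` is the total per-channel element count across all processes. A scalar sketch (`update_running_stats` is an illustrative helper, not an apex API):

```python
def update_running_stats(running_mean, running_var, batch_mean,
                         batch_var_biased, m, momentum=0.1):
    """Exponential-moving-average update used by the SyncBatchNorm fallback.

    m is the global element count per channel, so m / (m - 1) converts the
    biased batch variance into the unbiased estimate before mixing it in.
    """
    new_mean = momentum * batch_mean + (1 - momentum) * running_mean
    new_var = m / (m - 1) * momentum * batch_var_biased \
        + (1 - momentum) * running_var
    return new_mean, new_var
```

Note that `m` counts elements across every process in the group (`local_m * world_size` in the code above), so the bias correction reflects the effective synchronized batch, not the local one.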
  {
    "path": "KoSimCSE/apex/parallel/sync_batchnorm_kernel.py",
    "content": "import torch\nfrom torch.autograd.function import Function\n\nfrom apex.parallel import ReduceOp\n\n\nclass SyncBatchnormFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size):\n        torch.cuda.nvtx.range_push(\"sync_BN_fw\")\n        # transpose it to channel last to support broadcasting for input with different rank\n        c_last_input = input.transpose(1, -1).contiguous().clone()\n\n        ctx.save_for_backward(c_last_input, weight, bias,\n                              running_mean, running_variance)\n        ctx.eps = eps\n        ctx.process_group = process_group\n        ctx.world_size = world_size\n\n        c_last_input = (c_last_input - running_mean) / \\\n            torch.sqrt(running_variance + eps)\n\n        if weight is not None:\n            c_last_input = c_last_input * weight\n        if bias is not None:\n            c_last_input = c_last_input + bias\n\n        torch.cuda.nvtx.range_pop()\n        return c_last_input.transpose(1, -1).contiguous().clone()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        torch.cuda.nvtx.range_push(\"sync_BN_bw\")\n        # mini batch mean & var are calculated by forward path.\n        # mu = 1./N*np.sum(h, axis = 0)\n        # var = 1./N*np.sum((h-mu)**2, axis = 0)\n        c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors\n\n        eps = ctx.eps\n        process_group = ctx.process_group\n        world_size = ctx.world_size\n        grad_input = grad_weight = grad_bias = None\n        num_features = running_mean.size()[0]\n\n        # transpose it to channel last to support broadcasting for input with different rank\n        torch.cuda.nvtx.range_push(\"carilli field\")\n        c_last_grad = grad_output.transpose(1, -1).contiguous()\n        # squash non-channel dimension so we can easily calculate mean\n        c_grad = c_last_grad.view(-1, 
num_features).contiguous()\n        torch.cuda.nvtx.range_pop()\n\n        # calculate grad_input\n        if ctx.needs_input_grad[0]:\n            # dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0)\n            #     - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0))\n            mean_dy = c_grad.mean(0)\n            mean_dy_xmu = (c_last_grad * (c_last_input -\n                                          running_mean)).view(-1, num_features).mean(0)\n            if torch.distributed.is_initialized():\n                torch.distributed.all_reduce(\n                    mean_dy, ReduceOp.SUM, process_group)\n                mean_dy = mean_dy / world_size\n                torch.distributed.all_reduce(\n                    mean_dy_xmu, ReduceOp.SUM, process_group)\n                mean_dy_xmu = mean_dy_xmu / world_size\n            c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (\n                running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)\n            if weight is not None:\n                c_last_grad_input.mul_(weight)\n            grad_input = c_last_grad_input.transpose(1, -1).contiguous()\n\n        # calculate grad_weight\n        grad_weight = None\n        if weight is not None and ctx.needs_input_grad[1]:\n            # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0)\n            grad_weight = ((c_last_input - running_mean) / torch.sqrt(\n                running_variance + eps) * c_last_grad).view(-1, num_features).sum(0)\n\n        # calculate grad_bias\n        grad_bias = None\n        if bias is not None and ctx.needs_input_grad[2]:\n            # dbeta = np.sum(dy, axis=0)\n            grad_bias = c_grad.sum(0)\n\n        torch.cuda.nvtx.range_pop()\n        return grad_input, grad_weight, grad_bias, None, None, None, None, None\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/FAQs.md",
    "content": "1. How do I intercept the Adam optimizer in APEX?\n\n\t```python\n\tfrom apex import pyprof\n\timport fused_adam_cuda\n\tpyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\t```\n\n2. If you are using JIT and/or AMP, the correct initialization sequence is:\n\t1. Let any JIT compilation finish.\n\t2. Initialize pyprof: `pyprof.nvtx.init()`.\n\t3. Initialize AMP.\n\n3. How do I profile with `torch.distributed.launch`?\n\n\t```sh\n\tnvprof -f -o net%p.sql \\\n\t\t--profile-from-start off \\\n\t\t--profile-child-processes \\\n\t\tpython -m torch.distributed.launch net.py\n\t```\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/README.md",
    "content": "## PyProf - PyTorch Profiling tool\n\n### What does this tool do?\n\nAnalyzing the performance of deep neural networks is hard. Getting kernels out of [NvProf](https://developer.nvidia.com/nvidia-visual-profiler) or [NSight Compute](https://developer.nvidia.com/nsight-compute) provides the kernel name and its execution time, but not detailed information regarding the following:\n\n - Which layer launched it: e.g. the association of `ComputeOffsetsKernel` with a concrete PyTorch layer or API is not obvious.\n - What the tensor dimensions and precision were: without knowing the tensor dimensions and precision, it's impossible to reason about whether the actual (silicon) kernel time is close to maximum performance of such a kernel on the GPU. Knowing the tensor dimensions and precision, we can figure out the FLOPs and bandwidth required by a layer, and then determine how close to maximum performance the kernel is for that operation.\n - Forward-backward correlation: currently it's very hard to determine what the forward pass step was that resulted in the particular weight and data gradients (wgrad, dgrad), which makes it difficult to determine the tensor dimensions required by these backprop steps to assess their performance.\n - Did the kernel use [Tensor Cores](https://www.youtube.com/watch?v=yyR0ZoCeBO8)?\n - Which line in the user's code resulted in launching this particular kernel (program trace)?\n\nPyProf addresses all of the issues above by:\n\n 1. 
Instrumenting PyTorch operations to capture the tensor dimensions and precision using [NVTX](https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx). This information is recorded at profile capture time, e.g. using [NvProf](https://developer.nvidia.com/nvidia-visual-profiler).\n 2. Querying the record produced by the profiler to correlate the kernel name and duration with PyTorch API/layer name, tensor dimensions, tensor precision, as well as calculating FLOPs and bandwidth for common operations. In addition, extra information from the profile is added for use by CUDA professionals, such as CUDA launch parameters (block/grid dimensions).\n\nRegarding FLOP and bandwidth implementations, these are usually quite straightforward. For example, for matrices A<sub>MxK</sub> and B<sub>KxN</sub>, the FLOP count for a matrix multiplication is 2 * M * N * K, and bandwidth is M * K + N * K + M * N. Note that these numbers are based on the algorithm, not the actual performance of the specific kernel. For more details, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).\n\nArmed with such information, the user can determine various issues to help them tune the network. For instance, according to the [Tensor Core Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html), the M, N and K dimensions that result in Tensor Core usage need to be divisible by 8. In fact, PyProf comes with a flag that lets the user obtain information regarding whether Tensor Cores were used by the kernel. Other useful information might include knowing that a particular kernel did not exploit much thread parallelism, as determined by the grid/block dimensions. 
Since many PyTorch kernels are open-source (or even custom written by the user, as in [CUDA Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html)), this provides the user with information that helps root-cause performance issues and prioritize optimization work.\n\n\n### How to get started?\n\n1. Add the following lines to your PyTorch network:\n\n    ```python\n    import torch.cuda.profiler as profiler\n    from apex import pyprof\n    pyprof.nvtx.init()\n    ```\n\n    Run the training/inference loop with [PyTorch's NVTX context manager](https://pytorch.org/docs/stable/_modules/torch/autograd/profiler.html#emit_nvtx)\n    `with torch.autograd.profiler.emit_nvtx()`. Optionally, you can\n    use `profiler.start()` and `profiler.stop()` to pick an iteration\n    (say after warm-up) for which you would like to capture data.\n    Here's an example:\n\n    ```python\n    iters = 500\n    iter_to_capture = 100\n\n    # Define network, loss function, optimizer etc.\n\n    # PyTorch NVTX context manager\n    with torch.autograd.profiler.emit_nvtx():\n\n        for iter in range(iters):\n\n            if iter == iter_to_capture:\n                profiler.start()\n\n            output = net(images)\n            loss = criterion(output, labels)\n            loss.backward()\n            optimizer.step()\n\n            if iter == iter_to_capture:\n                profiler.stop()\n    ```\n\n2. Run NVprof to generate a SQL (NVVP) file. 
This file can be opened with NVVP, as usual.\n    ```sh\n    # If you used profiler.start() and profiler.stop() in net.py\n    nvprof -f -o net.sql --profile-from-start off -- python net.py\n\n    # Profile everything\n    nvprof -f -o net.sql -- python net.py\n    ```\n\n**Note:** If you're experiencing issues with hardware counters and you get a message such as `**_ERR_NVGPUCTRPERM The user running <tool_name/application_name> does not have permission to access NVIDIA GPU Performance Counters on the target device_**`, please follow the steps described in [Hardware Counters](#hardware-counters).\n\n3. Run the parser on the SQL file. The output is an ASCII file. Each line\nis a Python dictionary which contains information about the kernel name,\nduration, parameters etc. This file can be used as input to other custom\nscripts as well.\n\n    ```sh\n    python -m apex.pyprof.parse net.sql > net.dict\n    ```\n\n4. Run the profiler. The input is the Python dictionary created above. The tool can produce a CSV output, a columnated output (similar to `column -t` for terminal readability) and a space-separated output (for post-processing with AWK, for instance). The tool produces 20 columns of information for every GPU kernel, but you can select a subset of columns using the `-c` flag. Note that a few columns might have the value \"na\", implying either it is a work in progress or the tool was unable to extract that information. Assuming the directory is `prof`, here are a few examples of how to use `prof.py`.\n\n    ```sh\n\t# Print usage and help. 
Lists all available output columns.\n    python -m apex.pyprof.prof -h\n\n\t# Columnated output of width 150 with some default columns.\n    python -m apex.pyprof.prof -w 150 net.dict\n\n\t# CSV output.\n    python -m apex.pyprof.prof --csv net.dict\n\n\t# Space separated output.\n    python -m apex.pyprof.prof net.dict\n\n\t# Columnated output of width 130 with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof -w 130 -c idx,dir,kernel,params,sil net.dict\n\n\t# CSV output with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof --csv -c idx,dir,kernel,params,sil net.dict\n\n\t# Space separated output with columns index,direction,kernel name,parameters,silicon time.\n    python -m apex.pyprof.prof -c idx,dir,kernel,params,sil net.dict\n\n\t# Input redirection.\n    python -m apex.pyprof.prof < net.dict\n    ```\n\n5. Profile-guided optimization\n\nIf kernels that do matrix multiplication/GEMM or convolution use half precision (fp16) data but do not use Tensor Cores (the TC column in the profile analysis output doesn't show a \"1\"), one can follow some basic steps to increase the likelihood that a Tensor Core-compatible kernel will be chosen. For example, for GEMMs, M, N and K should be divisible by 8, and for convolutions, the number of input and output channels should be divisible by 8. 
For more information, see detailed Tensor Core guides such as:\n- Blog Post: [Tips for Optimizing GPU Performance Using Tensor Cores](https://devblogs.nvidia.com/optimizing-gpu-performance-tensor-cores/)\n- GTC Talk: [Tensor Core Deep Learning Performance Guide](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf)\n\nFor both Tensor Core and non-Tensor Core Deep Learning performance optimization tips, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).\n\n### TODOs\n1. The support for conv transpose is currently missing.\n2. PyProf currently works only with NvProf, but Nsight Compute support will be added in the future.\n\n### Example\n\n1. Run `nvprof` on the LeNet model in `examples/lenet.py`. This will output a SQL file called `net.sql`.\n\n```sh\nnvprof -f -o net.sql --profile-from-start off -- python examples/lenet.py\n```\n\n**Note**: DO NOT add --analysis-metrics since that will change which table nvprof writes the kernels to (`CUPTI_ACTIVITY_KIND_KERNEL` instead of the usual `CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL`). Support for running with metrics may be added in the future.\n\nIf you don't care about a full correlation analysis and you'd just like to view the timeline with detailed NVTX annotations, you can do so, e.g. in the NVIDIA Visual Profiler (NVVP). For example, you can call `nvvp net.sql` to view the annotated timeline.\n\n2. Run the `parse.py` script on `net.sql` to extract kernel and runtime information and\nsave it as `net.dict`.\n\n```sh\npython -m apex.pyprof.parse net.sql > net.dict\n```\n\nThis will produce a text file, which can be parsed by any external tool, but it can also be directly read one line at a time by Python by calling `eval` on the line being read. 
\n\n**Note: you do not need to process this output manually.** The output is shown here only to illustrate the tool's modularity: you can process the raw data yourself, or let the next step enrich the information further and dump a CSV.\n\nThe output of this step will look as follows. Note that the dictionary has a lot more keys than the ones shown in the example.\n\n```\n>>> with open('torchvision.resnet50.adam.64.dict') as f:\n...     for line in f:\n...         d = eval(line)\n...         print(d['kShortName'], d['op'], d['kDuration'], d['block'], d['grid'], d['device'], d['stream'], d['trace'])\n... \nnchwToNhwc3To4Kernel ['conv2d'] 376324 (256, 1, 1) (1568, 1, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\ngeneric4Channel_kernel ['conv2d'] 10720 (512, 1, 1) (19, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nfirst_layer_fwd_kernel ['conv2d'] 411204 (128, 1, 1) (2, 7, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nnhwcToNchwKernel ['conv2d'] 342371 (256, 1, 1) (392, 2, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195']\nelementwise_kernel ['__iadd__'] 2816 (128, 1, 1) (1, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196']\nbatch_norm_collect_statistics_kernel ['batch_norm', 'batch_norm'] 929513 (512, 1, 1) (64, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196']\n```\n\n3. Run the `prof.py` script on `net.dict` to summarize the results into a CSV file, or to display the pretty-printed results on the screen. 
This step processes the raw output from step 2 to generate a nice output, but it also adds a lot of extra useful information inferred from the previous step, such as:\n- FLOPs\n- bandwidth (bytes in and out of GPU DRAM)\n- tensor core usage\n\n```sh\npython -m apex.pyprof.prof --csv net.dict > results.csv\n```\n\nYou can choose which columns you'd like to display. Here's a list from calling `python -m apex.pyprof.prof -h`:\n\n```\n              idx:      Index\n              seq:      PyTorch Sequence Id\n              altseq:   PyTorch Alternate Sequence Id\n              tid:      Thread Id\n              layer:    User annotated NVTX string (can be nested)\n              trace:    Function Call Trace\n              dir:      Direction\n              sub:      Sub Sequence Id\n              mod:      Module\n              op:       Operation\n              kernel:   Kernel Name\n              params:   Parameters\n              sil:      Silicon Time (in ns)\n              tc:       Tensor Core Usage\n              device:   GPU Device Id\n              stream:   Stream Id\n              grid:     Grid Dimensions\n              block:    Block Dimensions\n              flops:    Floating point ops (FMA = 2 FLOPs)\n              bytes:    Number of bytes in and out of DRAM\n```              \n\nLet's have a look at the pretty-printed output:\n```\npython -m apex.pyprof.prof -w 100 -c kernel,op,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict\n\nKernel              Op              Sil(ns)    TC FLOPs        Bytes        Dev Str Block        Grid         \nelementwise_kernel  relu                381028 -      51380224    205520896   0   7 512,1,1      100352,1,1   \nvolta_fp16_s884cudn conv2d              160002 1    1644167168     51388416   0   7 256,1,1      784,1,1      \nelementwise_kernel  relu                 96545 -      12845056     51380224   0   7 512,1,1      25088,1,1    \nvolta_fp16_s884cudn conv2d              346083 1    
6576668672    128483328   0   7 256,1,1      784,2,1      \n```\n\nNot using the pretty-print width (`-w`) option and adding `--csv` results in a CSV output instead:\n\n```\npython -m apex.pyprof.prof --csv -c kernel,mod,op,dir,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict\n\n\"Kernel\",\"Module\",\"Op\",\"Direction\",\"Sil(ns)\",\"TC\",\"FLOPs\",\"Bytes\",\"Device\",\"Stream\",\"Block\",\"Grid\"\n\"nchwToNhwc3To4Kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"376324\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"1568,1,64\"\n\"generic4Channel_kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"10720\",\"-\",\"0\",\"0\",\"0\",\"7\",\"512,1,1\",\"19,1,1\"\n\"first_layer_fwd_kernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"411204\",\"-\",\"0\",\"0\",\"0\",\"7\",\"128,1,1\",\"2,7,64\"\n\"nhwcToNchwKernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"342371\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"392,2,64\"\n\"elementwise_kernel\",\"Tensor\",\"__iadd__\",\"fprop\",\"2816\",\"-\",\"1.0\",\"8\",\"0\",\"7\",\"128,1,1\",\"1,1,1\"\n\"batch_norm_collect_statistics_kernel\",\"torch.nn.functional\",\"batch_norm\",\"fprop\",\"929513\",\"-\",\"411041792\",\"411041792\",\"0\",\"7\",\"512,1,1\",\"64,1,1\"\n\"batch_norm_transform_input_kernel\",\"torch.nn.functional\",\"batch_norm\",\"fprop\",\"377539\",\"-\",\"411041792\",\"411041792\",\"0\",\"7\",\"512,1,1\",\"64,64,1\"\n\"elementwise_kernel\",\"torch.nn.functional\",\"relu\",\"fprop\",\"381028\",\"-\",\"51380224\",\"205520896\",\"0\",\"7\",\"512,1,1\",\"100352,1,1\"\n\"MaxPoolForward\",\"torch.nn.functional\",\"max_pool2d\",\"fprop\",\"406531\",\"-\",\"0\",\"0\",\"0\",\"7\",\"256,1,1\",\"50176,1,1\"\n\"cudnn::gemm::computeOffsetsKernel\",\"torch.nn.functional\",\"conv2d\",\"fprop\",\"2464\",\"-\",\"0\",\"0\",\"0\",\"7\",\"128,1,1\",\"25,1,1\"\n```\n\n### Hardware Counters\n\nProfiling GPU workloads may require access to [hardware performance 
counters](https://en.wikipedia.org/wiki/Hardware_performance_counter). Due to a [fix](https://nvidia.custhelp.com/app/answers/detail/a_id/4738) in recent NVIDIA drivers addressing [CVE‑2018‑6260](https://nvd.nist.gov/vuln/detail/CVE-2018-6260), the hardware counters are disabled by default, and require elevated privileges to be enabled again. If you're using a recent driver, you may see the following message when trying to run nvprof:\n\n```**_ERR_NVGPUCTRPERM The user running <tool_name/application_name> does not have permission to access NVIDIA GPU Performance Counters on the target device._**```\n\nFor details, see [here](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters).\n\n_Permanent solution_\n\nFollow the steps [here](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters). The current steps for Linux are:\n```\nsudo systemctl isolate multi-user\nsudo modprobe -r nvidia_uvm nvidia_drm nvidia_modeset nvidia-vgpu-vfio nvidia\nsudo modprobe nvidia NVreg_RestrictProfilingToAdminUsers=0\nsudo systemctl isolate graphical\n```\nThe above steps should result in a permanent change.\n\n_Temporary solution_\n\nWhen running on bare metal, you can run nvprof with `sudo`.\n\nIf you're running in a Docker image, you can temporarily elevate your privileges with one of the following (oldest to newest syntax):\n<pre>\nnvidia-docker run <b>--privileged</b>\ndocker run --runtime nvidia <b>--privileged</b>\ndocker run --gpus all <b>--privileged</b>\n</pre>\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/__init__.py",
    "content": "import warnings\n\nfrom . import nvtx, prof\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/.gitignore",
    "content": "__pycache__\n*.sql\n*.dict\n*.csv\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/apex/README.md",
    "content": "This directory has examples of how to use `pyprof` with APEX extensions, e.g. `fused_adam_cuda` and `fused_layer_norm_cuda`.\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/apex/fused_adam.py",
    "content": "import torch\nimport fused_adam_cuda\nfrom apex.optimizers import FusedAdam, FP16_Optimizer\nfrom apex import pyprof\n\npyprof.nvtx.init()\npyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\nmodel = torch.nn.Linear(10, 20).cuda().half()\ncriterion = torch.nn.CrossEntropyLoss().cuda()\noptimizer = FusedAdam(model.parameters())\noptimizer = FP16_Optimizer(optimizer)\n\nx = torch.ones(32, 10).cuda().half()\ntarget = torch.empty(32, dtype=torch.long).random_(20).cuda()\ny = model(x)\nloss = criterion(y, target)\noptimizer.zero_grad()\nloss.backward()\noptimizer.step()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/apex/fused_layer_norm.py",
    "content": "import torch\nimport fused_layer_norm_cuda\nfrom apex.normalization import FusedLayerNorm\nfrom apex import pyprof\n\npyprof.nvtx.init()\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward_affine')\npyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward_affine')\n\ninput = torch.randn(20, 5, 10, 10).cuda()\n\n# With Learnable Parameters\nm = FusedLayerNorm(input.size()[1:]).cuda()\noutput = m(input)\n\n# Without Learnable Parameters\nm = FusedLayerNorm(input.size()[1:], elementwise_affine=False).cuda()\noutput = m(input)\n\n# Normalize over last two dimensions\nm = FusedLayerNorm([10, 10]).cuda()\noutput = m(input)\n\n# Normalize over last dimension of size 10\nm = FusedLayerNorm(10).cuda()\noutput = m(input)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/apex/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/custom_func_module/README.md",
    "content": "This directory has examples which show how to intercept (monkey patch) custom functions and modules with `pyprof`. No changes are required in `pyprof/parse`; however, users can add support for byte and FLOP calculations for custom functions and modules in `pyprof/prof` by extending the `OperatorLayerBase` class.\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/custom_func_module/custom_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n#Initialize pyprof\npyprof.nvtx.init()\n\nclass Foo(torch.autograd.Function):\n\t@staticmethod\n\tdef forward(ctx, in1, in2):\n\t\tout = in1 + in2\t\t#This could be a custom C/C++ function.\n\t\treturn out\n\n\t@staticmethod\n\tdef backward(ctx, grad):\n\t\tin1_grad = grad\t\t#This could be a custom C/C++ function.\n\t\tin2_grad = grad\t\t#This could be a custom C/C++ function.\n\t\treturn in1_grad, in2_grad\n\n#Hook the forward and backward functions to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\npyprof.nvtx.wrap(Foo, 'backward')\n\nfoo = Foo.apply\n\nx = torch.ones(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x,y)\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/custom_func_module/custom_module.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\npyprof.nvtx.init()\n\nclass Foo(torch.nn.Module):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    def forward(self, input):\n        return self.n*input + self.m\n\n#Hook the forward function to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x)\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/custom_func_module/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/imagenet/imagenet.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nExample to run pyprof with imagenet models.\n\"\"\"\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torchvision.models as models\nimport torch.cuda.profiler as profiler\nimport argparse\n\nfrom apex import pyprof\nfrom apex.optimizers import FusedAdam\n\ndef parseArgs():\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"Run popular imagenet models.\")\n\n\tparser.add_argument(\"-m\",\n\t\ttype=str,\n\t\tdefault=\"resnet50\",\n\t\tchoices=[\"alexnet\", \"densenet121\", \"densenet161\", \"densenet169\", \"densenet201\", \"googlenet\", \"mnasnet0_5\", \"mnasnet0_75\", \"mnasnet1_0\", \"mnasnet1_3\", \"mobilenet_v2\", \"resnet18\", \"resnet34\", \"resnet50\", \"resnet101\", \"resnet152\", \"resnext50_32x4d\", \"resnext101_32x8d\", \"wide_resnet50_2\", \"wide_resnet101_2\", \"shufflenet_v2_x0_5\", \"shufflenet_v2_x1_0\", \"shufflenet_v2_x1_5\", \"shufflenet_v2_x2_0\", \"squeezenet1_0\", \"squeezenet1_1\", \"vgg11\", \"vgg11_bn\", \"vgg13\", \"vgg13_bn\", \"vgg16\", \"vgg16_bn\", \"vgg19\", \"vgg19_bn\", \"inception_v3\"],\n\t\thelp=\"Model.\")\n\n\tparser.add_argument(\"-b\",\n\t\ttype=int,\n\t\tdefault=32,\n\t\thelp=\"Batch size.\")\n\n\tparser.add_argument(\"-o\",\n\t\ttype=str,\n\t\tdefault=\"adam\",\n\t\tchoices=[\"adam\", \"sgd\"],\n\t\thelp=\"Optimizer.\")\n\n\targs = parser.parse_args()\n\treturn args\n\nd = {\n\t\"alexnet\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"densenet121\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet161\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet169\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"densenet201\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"googlenet\":\t\t\t{'H': 224, 'W': 224, 'opts': {'aux_logits': False}},\n\n\t\"mnasnet0_5\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet0_75\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet1_0\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"mnasnet1_3\":\t\t\t{'H': 224, 
'W': 224, 'opts': {}},\n\n\t\"mobilenet_v2\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"resnet18\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet34\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet50\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet101\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnet152\":\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"resnext50_32x4d\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"resnext101_32x8d\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"wide_resnet50_2\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"wide_resnet101_2\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"shufflenet_v2_x0_5\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x1_0\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x1_5\": \t{'H': 224, 'W': 224, 'opts': {}},\n\t\"shufflenet_v2_x2_0\":\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"squeezenet1_0\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"squeezenet1_1\":\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"vgg11\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg11_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg13\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg13_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg16\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg16_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg19\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\t\"vgg19_bn\":\t\t\t\t{'H': 224, 'W': 224, 'opts': {}},\n\n\t\"inception_v3\":\t\t\t{'H': 299, 'W': 299, 'opts': {'aux_logits': False}},\n\t}\n\ndef main():\n\targs = parseArgs()\n\n\tpyprof.nvtx.init()\n#\tpyprof.nvtx.wrap(fused_adam_cuda, 'adam')\n\n\tN = args.b\n\tC = 3\n\tH = d[args.m]['H']\n\tW = d[args.m]['W']\n\topts = d[args.m]['opts']\n\tclasses = 1000\n\n\tnet = getattr(models, args.m)\n\tnet = net(**opts).cuda().half()\n\tnet.train()\n\n\tx = torch.rand(N, C, H, W).cuda().half()\n\ttarget = torch.empty(N, dtype=torch.long).random_(classes).cuda()\n\n\tcriterion = 
nn.CrossEntropyLoss().cuda()\n\tif (args.o == \"sgd\"):\n\t\toptimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n\telif (args.o == \"adam\"):\n\t\toptimizer = FusedAdam(net.parameters())\n\telse:\n\t\tassert False\n\n\t#Warm up without profiler\n\tfor i in range(2):\n\t\toutput = net(x)\n\t\tloss = criterion(output, target)\n\t\toptimizer.zero_grad()\n\t\tloss.backward()\n\t\toptimizer.step()\n\n\twith torch.autograd.profiler.emit_nvtx():\n\t\tprofiler.start()\n\t\toutput = net(x)\n\t\tloss = criterion(output, target)\n\t\toptimizer.zero_grad()\n\t\tloss.backward()\n\t\toptimizer.step()\n\t\tprofiler.stop()\n\nif __name__ == \"__main__\":\n\tmain()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/imagenet/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python -m apex.pyprof.parse\"\nprof=\"python -m apex.pyprof.prof\"\n\nfor net in \"resnet50\"\ndo\n\tfor optim in adam sgd\n\tdo\n\t\tfor batch in 32 64\n\t\tdo\n\t\t\tbase=\"torchvision\".$net.$optim.$batch\n\t\t\tsql=$base.sql\n\t\t\tdict=$base.dict\n\n\t\t\t#NVprof\n\t\t\techo \"nvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch\"\n\t\t\tnvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch\n\n\t\t\t#Parse\n\t\t\techo $parse $sql\n\t\t\t$parse $sql > $dict\n\n\t\t\t#Prof\n\t\t\techo $prof $dict\n\t\t\t$prof -w 130 $dict\n#\t\t\t\\rm $sql $dict\n\t\tdone\n\tdone\ndone\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/README.md",
    "content": "*As of this writing, these examples do not work\nbecause of changes being proposed in PyTorch.*\n\nThere are two ways to use PyTorch JIT:\n - Scripting\n - Tracing\n\nIn addition, we can JIT a:\n - Stand-alone function\n - Class / class method\n\nThis directory has an example for each of the four cases.\nIntercepting (monkey patching) JITted code has a few extra steps,\nwhich are explained through comments.\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/jit_script_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\n#The following creates an object \"foo\" of type ScriptModule\n#The new object has a function called \"forward\"\n\n@torch.jit.script\ndef foo(x, y):\n\treturn torch.sigmoid(x) + y\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"foo\"\nfoo.__name__ = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(foo, 'forward')\n\nx = torch.zeros(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x, y)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/jit_script_method.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\nclass Foo(torch.jit.ScriptModule):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    @torch.jit.script_method\n    def forward(self, input):\n        return self.n*input + self.m\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(Foo, 'forward')\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = foo(x)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/jit_trace_function.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\ndef foo(x, y):\n\treturn torch.sigmoid(x) + y\n\nx = torch.zeros(4,4).cuda()\ny = torch.ones(4,4).cuda()\n\n#JIT the function using tracing\n#This returns an object of type ScriptModule with a forward method.\ntraced_foo = torch.jit.trace(foo, (x,y))\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"traced_foo\"\ntraced_foo.__dict__['__name__'] = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(traced_foo, 'forward')\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = traced_foo(x, y)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/jit_trace_method.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.cuda.profiler as profiler\nfrom apex import pyprof\n\nclass Foo(torch.nn.Module):\n    def __init__(self, size):\n        super(Foo, self).__init__()\n        self.n = torch.nn.Parameter(torch.ones(size))\n        self.m = torch.nn.Parameter(torch.ones(size))\n\n    def forward(self, input):\n        return self.n*input + self.m\n\nfoo = Foo(4)\nfoo.cuda()\nx = torch.ones(4).cuda()\n\n#JIT the class using tracing\ntraced_foo = torch.jit.trace(foo, x)\n\n#Initialize pyprof after the JIT step\npyprof.nvtx.init()\n\n#Assign a name to the object \"traced_foo\"\ntraced_foo.__dict__['__name__'] = \"foo\"\n\n#Hook up the forward function to pyprof\npyprof.nvtx.wrap(traced_foo, 'forward')\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\tz = traced_foo(x)\n\tprofiler.stop()\n\tprint(z)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/jit/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql python $f\"\n\tnvprof -fo $sql python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t$prof -w 130 $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/lenet.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.cuda.profiler as profiler\nimport torch.optim as optim\n\nfrom apex import pyprof\npyprof.nvtx.init()\n\nclass LeNet5(nn.Module):\n\tdef __init__(self):\n\t\tsuper(LeNet5, self).__init__()\n\t\t# 1 input image channel, 6 output channels, 5x5 square convolution\n\t\t# kernel\n\t\tself.conv1 = nn.Conv2d(1, 6, 5)\n\t\tself.conv2 = nn.Conv2d(6, 16, 5)\n\t\t# an affine operation: y = Wx + b\n\t\tself.fc1 = nn.Linear(16 * 5 * 5, 120)\n\t\tself.fc2 = nn.Linear(120, 84)\n\t\tself.fc3 = nn.Linear(84, 10)\n\n\tdef forward(self, x):\n\t\t# Max pooling over a (2, 2) window\n\t\tx = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n\t\t# If the size is a square you can only specify a single number\n\t\tx = F.max_pool2d(F.relu(self.conv2(x)), 2)\n\t\tx = x.view(-1, self.num_flat_features(x))\n\t\tx = F.relu(self.fc1(x))\n\t\tx = F.relu(self.fc2(x))\n\t\tx = self.fc3(x)\n\t\treturn x\n\n\tdef num_flat_features(self, x):\n\t\tsize = x.size()[1:]  # all dimensions except the batch dimension\n\t\tnum_features = 1\n\t\tfor s in size:\n\t\t\tnum_features *= s\n\t\treturn num_features\n\nwith torch.autograd.profiler.emit_nvtx():\n\n\tnet = LeNet5().cuda()\n\n\tinput = torch.randn(1, 1, 32, 32).cuda()\n\tout = net(input)\n\n\ttarget = torch.randn(10)\t\t\t# a dummy target, for example\n\ttarget = target.view(1, -1).cuda()\t# make it the same shape as output\n\tcriterion = nn.MSELoss()\n\n\t# create your optimizer\n\toptimizer = optim.SGD(net.parameters(), lr=0.01)\n\n\t# in your training loop:\n\toptimizer.zero_grad()\t# zero the gradient buffers\n\n\tprofiler.start()\n\toutput = net(input)\n\tloss = criterion(output, target)\n\tloss.backward()\n\toptimizer.step()\t# Does the update\n\tprofiler.stop()\n\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/operators.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis file checks all Python operators.\n\"\"\"\n\nimport sys\nimport torch\nimport torch.cuda.profiler as profiler\nimport operator\nimport inspect\n\n#Import and initialize pyprof\nfrom apex import pyprof\npyprof.nvtx.init()\n\nX = 1024\nY = 1024\n\nfa = torch.rand(X, Y).cuda()\nfb = torch.rand(X, Y).cuda()\nfc = torch.rand(X, Y).cuda()\n\nia = torch.randint(0, 100, (X, Y)).cuda()\nib = torch.randint(0, 100, (X, Y)).cuda()\n\nsa = torch.ones(1,1).cuda()\nsb = torch.ones(1,1).cuda()\n\nba = fa.byte()\n\nunaryOps = [\"abs\", \"__abs__\", \"neg\", \"__neg__\",]\ninvertOps = [\"inv\", \"invert\", \"__inv__\", \"__invert__\",]\t#imlemented only for byte tensors\n#pos, __pos__ is not implemented for tensors\n\nbinaryOps = []\nbinaryOps += [ \"lt\", \"__lt__\", \"le\", \"__le__\", \"eq\", \"__eq__\", \"ne\", \"__ne__\", \"ge\", \"__ge__\", \"gt\", \"__gt__\" ]\nbinaryOps += [ \"add\", \"__add__\", \"sub\", \"__sub__\", \"mul\", \"__mul__\", \"floordiv\", \"__floordiv__\", \"truediv\", \"__truediv__\", \"pow\", \"__pow__\", \"mod\", \"__mod__\"]\nbinaryOps += [ \"and_\", \"__and__\", \"or_\", \"__or__\", \"xor\", \"__xor__\", \"lshift\", \"__lshift__\", \"rshift\", \"__rshift__\"]\n\ninplaceOps = []\ninplaceOps += [\"iadd\", \"__iadd__\", \"isub\", \"__isub__\", \"imul\", \"__imul__\", \"ifloordiv\", \"__ifloordiv__\", \"itruediv\", \"__itruediv__\", \"imod\", \"__imod__\",]\n#ipow, __ipow__ is not implemented in pytorch\ninplaceOps += [ \"iand\", \"__iand__\", \"ior\", \"__ior__\", \"ixor\", \"__ixor__\", \"ilshift\", \"__ilshift__\", \"irshift\", \"__irshift__\",]\n\nmatmulOps = [ \"matmul\", \"__matmul__\" ]\ninplacematmulOps = [ \"imatmul\", \"__imatmul__\" ]\n\nreverseIntBinaryOps = [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rfloordiv__\", \"__rpow__\",]\nreverseFloatBinaryOps = [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rdiv__\", \"__rtruediv__\", \"__rfloordiv__\", \"__rpow__\",]\n\n'''\nTODO\n.concat(a, 
b)\n.__concat__(a, b)\n.contains(a, b)\n.__contains__(a, b)\n.countOf(a, b)\n.delitem(a, b)\n.__delitem__(a, b)\n.getitem(a, b)\n.__getitem__(a, b)\n.indexOf(a, b)\n.setitem(a, b, c)\n.__setitem__(a, b, c)\n.length_hint(obj, default=0)\n.iconcat(a, b)\n.__iconcat__(a, b)\n.index(a)\n.__index__(a)\n'''\n\n#Context manager\nwith torch.autograd.profiler.emit_nvtx():\n\n\t#Start profiler\n\tprofiler.start()\n\n\tfor op in unaryOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ia)\n\n\tfor op in invertOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ba)\n\n\tfor op in binaryOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(ia, ib)\n\t\tc = f(ia, 2)\n\n\tfor op in inplaceOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tia = f(ia, ib)\n\t\tia = f(ia, 2)\n\n\tfor op in matmulOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tc = f(fa, fb)\n\n\tfor op in inplacematmulOps:\n\t\tassert hasattr(operator, op)\n\t\tf = getattr(operator, op)\n\t\tassert inspect.isbuiltin(f)\n\t\tfa = f(fa, fb)\n\n\tfor op in reverseIntBinaryOps:\n\t\tassert hasattr(torch.Tensor, op)\n\t\tf = getattr(torch.Tensor, op)\n\t\tia = f(ia, ib)\n\n\tfor op in reverseFloatBinaryOps:\n\t\tassert hasattr(torch.Tensor, op)\n\t\tf = getattr(torch.Tensor, op)\n\t\tfa = f(fa, fb)\n\n\t'''\n\t#c = fa[3]\n\t#c = fa[3][3]\n\t#c = torch.min(fa, 3)\n\tc = torch.sum(fa)\n\tc = torch.max(fa)\n\tc = -fa\n\t#fc[2][2] = fa[2][2]\n\n\tc = a_scalar and b_scalar\n\tc = a_scalar or b_scalar\n\tc = not a_scalar\n\n\tc = a is b\n\tc = a is not b\n\t'''\n\n\t#Stop profiler\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/simple.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis simple file provides an example of how to\n - import the pyprof library and initialize it\n - use the emit_nvtx context manager\n - start and stop the profiler\n\nOnly kernels within profiler.start and profiler.stop calls are profiled.\nTo profile\n$ nvprof -f -o simple.sql --profile-from-start off ./simple.py\n\"\"\"\n\nimport sys\nimport torch\nimport torch.cuda.profiler as profiler\n\n#Import and initialize pyprof\nfrom apex import pyprof\npyprof.nvtx.init()\n\na = torch.randn(5, 5).cuda()\nb = torch.randn(5, 5).cuda()\n\n#Context manager\nwith torch.autograd.profiler.emit_nvtx():\n\n\t#Start profiler\n\tprofiler.start()\n\n\tc = a + b\n\tc = torch.mul(a,b)\n\tc = torch.matmul(a,b)\n\tc = torch.argmax(a, dim=1)\n\tc = torch.nn.functional.pad(a, (1,1))\n\n\t#Stop profiler\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/user_annotation/README.md",
    "content": "Nvidia NVTX range markers (https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm)\nare a useful tool for capturing and observing events and code ranges.\nUsing PyTorch APIs, e.g. `torch.cuda.nvtx.range_push(\"xxx\")` and `torch.cuda.nvtx.range_pop()`, users can easily add their own NVTX range markers. These markers can then be observed in the Nvidia Visual Profiler (NVVP).\n\nWhen inserting NVTX markers (strings), if users follow the specific string pattern `\"layer:your_string_here\"`, e.g. `\"layer:conv1\"` or `\"layer:encoder_layer_3_self_attention\"`, then `pyprof` will display the strings `conv1` and `encoder_layer_3_self_attention` next to the associated kernels in the output of `prof.py` when used with the `-c layer` option.\n\nNVTX range markers can be nested, and if users follow the above string pattern, the output of `prof.py` will show all the markers associated with a kernel.\n\nThe file `resnet.py` (a simplified version of the torchvision model) shows an example of how users can add (nested) NVTX markers with information that can greatly aid in the understanding and analysis of networks.\n\nNote that the pattern `\"layer:your_string_here\"` was chosen to aid information extraction by `pyprof`. The tool will work seamlessly even if there are other markers or no markers at all.\n\n### To run\n\n```sh\nnvprof -fo resnet.sql --profile-from-start off python resnet.py\nparse.py resnet.sql > resnet.dict\nprof.py --csv -c idx,layer,dir,mod,op,kernel,params,sil resnet.dict\n```\n\nThe file `resnet.sql` can also be opened with NVVP as usual.\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/user_annotation/resnet.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nAn example showing use of nested NVTX markers.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nimport torch.cuda.profiler as profiler\nimport torch.cuda.nvtx as nvtx\nfrom apex import pyprof\npyprof.nvtx.init()\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n\t\"\"\"3x3 convolution with padding\"\"\"\n\treturn nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n\t\t\t\t\t padding=dilation, groups=groups, bias=False, dilation=dilation)\n\ndef conv1x1(in_planes, out_planes, stride=1):\n\t\"\"\"1x1 convolution\"\"\"\n\treturn nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\nclass Bottleneck(nn.Module):\n\texpansion = 4\n\tcount = 1\n\n\tdef __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n\t\t\t\t base_width=64, dilation=1, norm_layer=None):\n\t\tsuper(Bottleneck, self).__init__()\n\t\tif norm_layer is None:\n\t\t\tnorm_layer = nn.BatchNorm2d\n\t\twidth = int(planes * (base_width / 64.)) * groups\n\t\t# Both self.conv2 and self.downsample layers downsample the input when stride != 1\n\t\tself.conv1 = conv1x1(inplanes, width)\n\t\tself.bn1 = norm_layer(width)\n\t\tself.conv2 = conv3x3(width, width, stride, groups, dilation)\n\t\tself.bn2 = norm_layer(width)\n\t\tself.conv3 = conv1x1(width, planes * self.expansion)\n\t\tself.bn3 = norm_layer(planes * self.expansion)\n\t\tself.relu = nn.ReLU(inplace=True)\n\t\tself.downsample = downsample\n\t\tself.stride = stride\n\n\t\tself.id = Bottleneck.count\n\t\tBottleneck.count += 1\n\n\tdef forward(self, x):\n\t\tidentity = x\n\n\t\tnvtx.range_push(\"layer:Bottleneck_{}\".format(self.id))\n\n\t\tnvtx.range_push(\"layer:Conv1\")\n\t\tout = self.conv1(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN1\")\n\t\tout = self.bn1(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Conv2\")\n\t\tout = 
self.conv2(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN2\")\n\t\tout = self.bn2(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Conv3\")\n\t\tout = self.conv3(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:BN3\")\n\t\tout = self.bn3(out)\n\t\tnvtx.range_pop()\n\n\t\tif self.downsample is not None:\n\t\t\tnvtx.range_push(\"layer:Downsample\")\n\t\t\tidentity = self.downsample(x)\n\t\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:Residual\")\n\t\tout += identity\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:ReLU\")\n\t\tout = self.relu(out)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_pop()\n\n\t\treturn out\n\nclass ResNet(nn.Module):\n\n\tdef __init__(self, block, layers, num_classes=1000,\n\t\t\t\t groups=1, width_per_group=64, norm_layer=None):\n\t\tsuper(ResNet, self).__init__()\n\t\tif norm_layer is None:\n\t\t\tnorm_layer = nn.BatchNorm2d\n\t\tself._norm_layer = norm_layer\n\n\t\tself.inplanes = 64\n\t\tself.dilation = 1\n\n\t\tself.groups = groups\n\t\tself.base_width = width_per_group\n\t\tself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n\t\tself.bn1 = norm_layer(self.inplanes)\n\t\tself.relu = nn.ReLU(inplace=True)\n\t\tself.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\t\tself.layer1 = self._make_layer(block, 64, layers[0])\n\t\tself.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n\t\tself.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n\t\tself.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n\t\tself.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\t\tself.fc = nn.Linear(512 * block.expansion, num_classes)\n\n\t\tfor m in self.modules():\n\t\t\tif isinstance(m, nn.Conv2d):\n\t\t\t\tnn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n\t\t\telif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n\t\t\t\tnn.init.constant_(m.weight, 
1)\n\t\t\t\tnn.init.constant_(m.bias, 0)\n\n\tdef _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n\t\tnorm_layer = self._norm_layer\n\t\tdownsample = None\n\t\tprevious_dilation = self.dilation\n\t\tif dilate:\n\t\t\tself.dilation *= stride\n\t\t\tstride = 1\n\t\tif stride != 1 or self.inplanes != planes * block.expansion:\n\t\t\tdownsample = nn.Sequential(\n\t\t\t\tconv1x1(self.inplanes, planes * block.expansion, stride),\n\t\t\t\tnorm_layer(planes * block.expansion),\n\t\t\t)\n\n\t\tlayers = []\n\t\tlayers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n\t\t\t\t\t\t\tself.base_width, previous_dilation, norm_layer))\n\t\tself.inplanes = planes * block.expansion\n\t\tfor _ in range(1, blocks):\n\t\t\tlayers.append(block(self.inplanes, planes, groups=self.groups,\n\t\t\t\t\t\t\t\tbase_width=self.base_width, dilation=self.dilation,\n\t\t\t\t\t\t\t\tnorm_layer=norm_layer))\n\n\t\treturn nn.Sequential(*layers)\n\n\tdef forward(self, x):\n\n\t\tnvtx.range_push(\"layer:conv1_x\")\n\t\tx = self.conv1(x)\n\t\tx = self.bn1(x)\n\t\tx = self.relu(x)\n\t\tx = self.maxpool(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv2_x\")\n\t\tx = self.layer1(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv3_x\")\n\t\tx = self.layer2(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv4_x\")\n\t\tx = self.layer3(x)\n\t\tnvtx.range_pop()\n\n\t\tnvtx.range_push(\"layer:conv5_x\")\n\t\tx = self.layer4(x)\n\t\tnvtx.range_pop()\n\n\t\tx = self.avgpool(x)\n\t\tx = torch.flatten(x, 1)\n\n\t\tnvtx.range_push(\"layer:FC\")\n\t\tx = self.fc(x)\n\t\tnvtx.range_pop()\n\n\t\treturn x\n\n\ndef resnet50():\n\treturn ResNet(Bottleneck, [3, 4, 6, 3])\n\n#Create model\nnet = resnet50().cuda().half()\nnet.train()\n\n#Create optimizer\ncriterion = nn.CrossEntropyLoss().cuda()\noptimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n\n#Create synthetic input and label\nx = torch.rand(32, 3, 224, 224).cuda().half()\ntarget 
= torch.empty(32, dtype=torch.long).random_(1000).cuda()\n\nwith torch.autograd.profiler.emit_nvtx():\n\tprofiler.start()\n\toutput = net(x)\n\tloss = criterion(output, target)\n\toptimizer.zero_grad()\n\tloss.backward()\n\toptimizer.step()\n\tprofiler.stop()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/examples/user_annotation/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nSCRIPT=`realpath $0`\nSCRIPTPATH=`dirname $SCRIPT`\nPYPROF=\"$SCRIPTPATH/../..\"\n\nparse=\"python $PYPROF/parse/parse.py\"\nprof=\"python $PYPROF/prof/prof.py\"\n\nfor f in *.py\ndo\n\tbase=`basename $f .py`\n\tsql=$base.sql\n\tdict=$base.dict\n\n\t#NVprof\n\techo \"nvprof -fo $sql --profile-from-start off python $f\"\n\tnvprof -fo $sql --profile-from-start off python $f\n\n\t#Parse\n\techo $parse $sql\n\t$parse $sql > $dict\n\n\t#Prof\n\techo $prof $dict\n\t#$prof -w 130 $dict\n\t$prof --csv -c idx,layer,dir,mod,op,kernel,params,sil $dict\n\t\\rm $sql $dict\ndone\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/nvtx/__init__.py",
    "content": "from .nvmarker import init\nfrom .nvmarker import add_wrapper as wrap\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/nvtx/nvmarker.py",
    "content": "\"\"\"\nThis file intercepts (monkey patches) the following functions and adds NVTX markers.\n\ttorch.*\n\ttorch.Tensor.*\n\ttorch.nn.functional.*\n\ttorch.nn.*.forward\n\nThe NVTX markers (one or more) contain the following information\n\tcall trace (a list of file_name:line_number)\n\textra_repr() from torch.nn modules\n\tmodule/class name\n\tfunction name\n\tinputs (args and kwargs)\n\t\tscalar: name, type and value\n\t\ttensor: name, shape and datatype\n\t\tnumpy: name, shape and datatype\n\t\tlist/tuple: a sequence of scalars or tensors or numpy arrays\n\"\"\"\n\nimport torch\nimport torch.cuda.nvtx as nvtx\nimport numpy\nimport inspect as ins\nimport traceback\nimport math\n\ndef isfunc(mod, f):\n\tassert hasattr(mod, f)\n\tattr = getattr(mod, f)\n\n\t#Ignore functions like _add\n\tif (len(f) >= 2):\n\t\tif f[0] == \"_\" and f[1] != \"_\":\n\t\t\treturn False\n\n\t#Ignore functions from this list\n\tignore = ['__all__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__builtins__', '__cached__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__file__', '__format__', '__getattribute__', '__getitem__', '__hash__', '__index__', '__init__', '__init_subclass__', '__iter__', '__len__', '__loader__', '__module__', '__name__', '__new__', '__nonzero__', '__package__', '__path__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__spec__', '__str__', '__subclasshook__', '__version__', '__weakref__']\n\n\t#Add functions to this list if they cause recursion\n\tignore += ['size', 'tolist', 'dim', 'is_storage', 'item']\n\tif f in ignore:\n\t\treturn False\n\n\treturn ins.ismethod(attr) or ins.isfunction(attr) or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr)\n\ndef traceMarker(stack):\n\td = {}\n\tcadena = []\n\tfor i in range(len(stack)-1):\n\t\tfi = stack[i]\n\t\tt = \"{}:{}\".format(fi.filename, 
fi.lineno)\n\t\tcadena.append(t)\n\td['traceMarker'] = cadena\n\treturn str(d)\n\ndef modMarker(mod, fn_name, args):\n\t\"\"\"\n\tReturns the stringified extra_repr() of a module.\n\t\"\"\"\n\tassert(fn_name == 'forward')\n\tassert(len(args) > 0)\n\td = {}\n\td['mod'] = mod.__name__\n\td['strRepr'] = args[0].extra_repr()\n\treturn str(d)\n\ndef add_wrapper(mod, fn_name):\n\tassert isfunc(mod, fn_name)\n\n\t# Get a pointer to the original function\n\tfunc = getattr(mod, fn_name)\n\n\t# Check if the mod has a string representation\n\t# and is not a Script or Traced module (used by JIT)\n\ts = hasattr(mod, \"extra_repr\") and (type(mod) is not torch.jit.ScriptModule) and (type(mod) is not torch.jit.TopLevelTracedModule)\n\n\tdef wrapper_func(*args, **kwargs):\n\n\t\t# Extract the stacktrace\n\t\tstack = traceback.extract_stack()\n\n\t\t# Push trace marker\n\t\tnvtx.range_push(traceMarker(stack))\n\n\t\t# Push module marker\n\t\tif s:\n\t\t\tm = modMarker(mod, fn_name, args)\n\t\t\tnvtx.range_push(m)\n\n\t\t# Create and push argument marker\n\t\tcadena = argMarker(mod, fn_name, args, kwargs)\n\t\tnvtx.range_push(cadena)\n\n\t\t# Call the original function\n\t\tresult = func(*args, **kwargs)\n\n\t\t# Pop argument marker\n\t\tnvtx.range_pop()\n\n\t\t# Pop module marker\n\t\tif s:\n\t\t\tnvtx.range_pop()\n\n\t\t# Pop trace marker\n\t\tnvtx.range_pop()\n\n\t\treturn result\n\tsetattr(mod, fn_name, wrapper_func)\n\ndef argMarker(mod, op, args, kwargs):\n\t#For this function args is a tuple and kwargs is a dict\n\n\tdef tensor(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = \"tensor\"\n\t\ta['shape'] = tuple(arg.size())\n\t\ta['dtype'] = str(arg.dtype).split(\".\")[-1]\n\t\tcadena['args'].append(a)\n\n\tdef ndarray(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = \"ndarray\"\n\t\ta['shape'] = arg.shape\n\t\ta['dtype'] = str(arg.dtype).split(\".\")[-1]\n\t\tcadena['args'].append(a)\n\n\tdef seq(arg, name=\"\"):\n\t\tassert 
issequence(arg)\n\t\ta = {}\n\t\ta['name'] = name\n\t\tif isinstance(arg, list):\n\t\t\ta['type'] = \"list\"\n\t\t\ta['value'] = arg\n\t\telse:\n\t\t\ta['type'] = \"tuple\"\n\t\t\t# The arg could be torch.Size, which is a subclass of tuple\n\t\t\t# Therefore, explicitly convert to tuple\n\t\t\ta['value'] = tuple(arg)\n\t\t\n\t\tcadena['args'].append(a)\n\n\tdef scalar(arg, name=\"\"):\n\t\ta = {}\n\t\ta['name'] = name\n\t\ta['type'] = type(arg).__name__\n\t\t#handle the case when the argument is +/- inf or nan\n\t\tif arg == float('inf'):\n\t\t\ta['value'] = \"inf\"\n\t\telif arg == float('-inf'):\n\t\t\ta['value'] = \"-inf\"\n\t\telif isinstance(arg, float) and math.isnan(arg):\n\t\t\ta['value'] = \"nan\"\n\t\telse:\n\t\t\ta['value'] = arg\n\t\tcadena['args'].append(a)\n\n\tdef isscalar(arg):\n\t\treturn (type(arg) is int) or (type(arg) is float) or (type(arg) is bool) or (arg is None) or (type(arg) is str)\n\n\tdef issequence(arg):\n\t\treturn isinstance(arg, list) or isinstance(arg, tuple)\n\n\tdef foo(args, name):\n\t\t#args should be an iterable sequence e.g. 
list or tuple\n\t\tfor arg in args:\n\t\t\tif isinstance(arg, torch.Tensor):\n\t\t\t\tif arg.dim() == 0:\n\t\t\t\t\tscalar(arg.item(), name)\n\t\t\t\telse:\n\t\t\t\t\ttensor(arg, name)\n\t\t\telif isinstance(arg, numpy.ndarray):\n\t\t\t\tndarray(arg, name)\n\t\t\telif (isscalar(arg)):\n\t\t\t\tscalar(arg, name)\n\t\t\telif issequence(arg):\n\t\t\t\tif (len(arg) == 0) or isscalar(arg[0]):\t#An empty sequence or a sequence of scalars\n\t\t\t\t\tseq(arg, name)\n\t\t\t\telse:\t# A sequence of tensors or numpy arrays\n\t\t\t\t\tfoo(arg, name)\n\t\t\t'''\n\t\t\telse:\n\t\t\t\tprint(\"The following arg is none of Tensor, numpy array, scalar but a %s\" % (str(type(arg))))\n\t\t\t\tprint(\"Mod: %s\" % str(mod.__name__))\n\t\t\t\tprint(\"Op: %s\" % str(op))\n\t\t\t\tprint(dir(arg))\n\t\t\t'''\n\n\tcadena = {}\n\tcadena['mod'] = mod.__name__\n\tcadena['op'] = op\n\tcadena['args'] = []\n\n\tfoo(args, \"\")\n\tfor k,v in kwargs.items():\n\t\tfoo((v,), k)\n\n\treturn str(cadena)\n\ndef patchClass(cls):\n\tfor f in dir(cls):\n\t\tif isfunc(cls, f):\n\t\t\tadd_wrapper(cls, f)\n\ndef init():\n\tstring = \"\\n\\nPyprof has been moved to its own dedicated repository and will \" + \\\n\t\t\t\"soon be removed from Apex.  Please visit\\n\" + \\\n\t\t\t\"https://github.com/NVIDIA/PyProf\\n\" + \\\n\t\t\t\"for the latest version.\\n\\n\"\n\t# print regardless of warning state\n\tprint(string)\n\n\tprint(\"Initializing NVTX monkey patches\")\n\tfor cls in [torch, torch.Tensor, torch.nn.functional,]:\n\t\tpatchClass(cls)\n\n\tfor cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]:\n\t\tif isfunc(cls, 'forward'):\n\t\t\tadd_wrapper(cls, 'forward')\n\n\tprint(\"Done with NVTX monkey patching\")\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/__init__.py",
    "content": ""
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/__main__.py",
    "content": "import warnings\n\ntry:\n    from .parse import main\nexcept ImportError as e:\n    warnings.warn(\"Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?\")\n    raise e\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/db.py",
    "content": "import sys, sqlite3\n\nclass DB(object):\n\t\"\"\"\n\tThis class provides functions for DB operations\n\twith exception handling.\n\t\"\"\"\n\n\tdef __init__(self, dbFile):\n\t\ttry:\n\t\t\tconn = sqlite3.connect(dbFile)\n\t\t\tconn.row_factory = sqlite3.Row\n\t\t\tc = conn.cursor()\n\t\texcept:\n\t\t\tprint(\"Error opening {}\".format(dbFile))\n\t\t\tsys.exit(1)\n\n\t\tself.conn = conn\n\t\tself.c = c\n\n\tdef select(self, cmd):\n\t\ttry:\n\t\t\tself.c.execute(cmd)\n\t\t\t#rows = self.c.fetchall()\n\t\t\trows = [dict(row) for row in self.c.fetchall()]\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\t\t#print(rows)\n\t\treturn rows\n\n\tdef insert(self, cmd, data):\n\t\ttry:\n\t\t\tself.c.execute(cmd, data)\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\tdef execute(self, cmd):\n\t\ttry:\n\t\t\tself.c.execute(cmd)\n\t\texcept sqlite3.Error as e:\n\t\t\tprint(e)\n\t\t\tsys.exit(1)\n\t\texcept:\n\t\t\tprint(\"Uncaught error in SQLite access while executing {}\".format(cmd))\n\t\t\tsys.exit(1)\n\n\tdef commit(self):\n\t\tself.conn.commit()\n\n\tdef close(self):\n\t\tself.c.close()\n\t\tself.conn.close()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/kernel.py",
    "content": "import cxxfilt, struct, binascii\n\n#Helper functions\n\ndef demangle(name):\n\t\"\"\"\n\tDemangle a C++ string\n\t\"\"\"\n\treturn cxxfilt.demangle(name)\n\ndef encode_object_id(pid, tid):\n\t\"\"\"\n\tGiven process id (pid) and thread id (tid), return the object id.\n\tobject id = pid (little endian 4 bytes) + tid (little endian 8 bytes)\n\t\"\"\"\n\tobjId = struct.pack('<i', pid) + struct.pack('<q',tid)\n\tobjId = binascii.hexlify(objId).decode('ascii').upper()\n\treturn objId\n\ndef getShortName(name):\n\t\"\"\"\n\tReturns a shorter kernel name\n\t\"\"\"\n\tsname = name.split(\"<\")[0] \\\n\t\t\t\t.replace(\"void \", \"\") \\\n\t\t\t\t.replace(\"at::\",\"\") \\\n\t\t\t\t.replace(\"cuda::\", \"\") \\\n\t\t\t\t.replace(\"native::\",\"\") \\\n\t\t\t\t.replace(\"(anonymous namespace)::\", \"\")\n\tsname = sname.split(\"(\")[0]\n\treturn sname\n\nclass Kernel(object):\n\t\"\"\"\n\tThis class stores information about a kernel.\n\t\"\"\"\n\n\tkernels = []\n\tprofStart = 0\n\n\tdef __init__(self):\n\t\tself.kNameId = None\n\t\tself.kShortName = None\n\t\tself.kLongName = None\n\t\tself.kStartTime = None\t#GPU start time\n\t\tself.kEndTime = None\t#GPU end time\n\t\tself.kDuration = None\n\t\tself.device = None\n\t\tself.stream = None\n\t\tself.grid = ()\n\t\tself.block = ()\n\t\tself.corrId = None\n\t\tself.rStartTime = None\t#CPU start time\n\t\tself.rEndTime = None\t#CPU end time\n\t\tself.rDuration = None\n\t\tself.tid = None\n\t\tself.pid = None\n\t\tself.objId = None\n\t\tself.timeOffset = None\n\n\t\tself.layerMarkers = []\n\t\tself.traceMarkers = []\n\t\tself.reprMarkers = []\n\t\tself.pyprofMarkers = []\n\t\tself.seqMarkers = []\n\t\tself.otherMarkers = []\n\t\tself.altMarkers = []\n\t\tself.seqId = []\n\t\tself.altSeqId = []\n\t\tself.layer = []\n\n\t\tself.subSeqId = None\n\t\tself.dir = None\n\t\tself.mod = []\n\t\tself.op = []\n\n\tdef setKernelInfo(self, info):\n\t\tself.kNameId = info['name']\n\t\tself.corrId = 
int(info['correlationId'])\n\t\tstart = int(info['start'])\n\t\tend = int(info['end'])\n\t\tassert end > start, \"This assertion can fail for very large profiles. It usually fails when start = end = 0.\"\n\t\tself.kStartTime = start\n\t\tself.kEndTime = end\n\t\tself.kDuration = end - start\n\t\tassert (start > Kernel.profStart)\n\t\tself.device = int(info['deviceId'])\n\t\tself.stream = int(info['streamId'])\n\t\tself.grid = (info['gridX'], info['gridY'], info['gridZ'])\n\t\tself.block = (info['blockX'], info['blockY'], info['blockZ'])\n\t\tself.timeOffset = Kernel.profStart\n\n\tdef setKernelName(self, name):\n\t\tcadena = demangle(name)\n\t\tself.kLongName = cadena\n\t\tself.kShortName = getShortName(cadena)\n\n\tdef setRunTimeInfo(self, info):\n\t\tstart, end, pid, tid = info\n\t\tself.rStartTime = start\n\t\tself.rEndTime = end\n\t\tself.rDuration = end - start\n\t\tself.pid = pid\n\t\tself.tid = tid\n\t\tself.objId = encode_object_id(pid, tid)\n\n\tdef setMarkerInfo(self, info):\n\t\tself.layerMarkers, self.traceMarkers, self.reprMarkers, self.pyprofMarkers, self.seqMarkers, self.otherMarkers, self.altMarkers, self.seqId, self.altSeqId, self.layer = info\n\t\tself.subSeqId = 0\n\n\tdef setDirection(self):\n\t\t\"\"\"\n\t\tSet direction (fprop, bprop) based on PyTorch sequence markers.\n\t\tIt is a heuristic and not a foolproof method.\n\t\t\"\"\"\n\t\tif\tany(\"Backward, seq = \" in x for x in self.seqMarkers) or \\\n\t\t\tany(\"backward, seq = \" in x for x in self.seqMarkers) or \\\n\t\t\tany(\"Backward0, seq = \" in x for x in self.seqMarkers):\n\t\t\tself.dir = \"bprop\"\n\t\telse:\n\t\t\tself.dir = \"fprop\"\n\n\tdef setOp(self):\n\t\t\"\"\"\n\t\tDetect and set the class/module (mod) and operation (op)\n\t\tof the kernel e.g. 
torch.nn.functional / linear, torch / sigmoid.\n\t\tThe lookup sequence we use is\n\t\t\tNVTX markers inserted by pyprof\n\t\t\tNVTX markers inserted by PyTorch in bprop\n\t\t\tNVTX markers inserted by PyTorch in fprop\n\t\tIt is a heuristic and not a foolproof method.\n\t\t\"\"\"\n\n\t\tdef sanitize(name):\n\t\t\tname = name.replace(\"torch\",\"\") \\\n\t\t\t\t\t\t.replace(\"autograd\",\"\") \\\n\t\t\t\t\t\t.replace(\"_backward\",\"\") \\\n\t\t\t\t\t\t.replace(\"::\",\"\") \\\n\t\t\t\t\t\t.replace(\"jit\",\"\") \\\n\t\t\t\t\t\t.replace(\"(anonymous namespace)\",\"\")\n\t\t\thead, sep, tail = name.partition(\"Backward\")\n\t\t\treturn head\n\n\t\t#Check pyprof markers\n\t\tfor m in self.pyprofMarkers:\n\t\t\tassert (\"mod\" in m) and (\"op\" in m) and (\"args\" in m)\n\t\t\tt = eval(m)\n\t\t\tself.op.append(t['op'])\n\t\t\tself.mod.append(t['mod'])\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#Check bprop kernel markers\n\t\tfor m in self.seqMarkers:\n\t\t\tif (\"backward, seq = \" in m) or (\"Backward, seq = \" in m):\n\t\t\t\top = m.split(\",\")[0]\n\t\t\t\top = sanitize(op)\n\t\t\t\tself.op.append(op)\n\t\t\t\tself.mod.append('na')\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#Check markers with \"seq = \"\n\t\tfor m in self.seqMarkers:\n\t\t\tif \", seq = \" in m:\n\t\t\t\top = m.split(\",\")[0]\n\t\t\t\tself.op.append(op)\n\t\t\t\tself.mod.append('na')\n\n\t\tif len(self.op):\n\t\t\treturn\n\n\t\t#If nothing else\n\t\tif len(self.otherMarkers):\n\t\t\tself.op.append(self.otherMarkers[0])\n\t\tself.mod.append('na')\n\n\tdef print(self):\n\t\t\"\"\"\n\t\tPrint kernel information. 
This is used by prof.py.\n\t\t\"\"\"\n\n\t\ta = lambda: None\n\t\ta.kShortName = self.kShortName\n\t\ta.kDuration = self.kDuration\n\t\t#a.layerMarkers = self.layerMarkers\n\t\ta.layer = self.layer\n\t\ta.trace = self.traceMarkers\n\t\ta.reprMarkers = self.reprMarkers\n\t\ta.marker = self.pyprofMarkers\n\t\ta.seqMarker = self.seqMarkers\n\n\t\ta.seqId = self.seqId\n\t\ta.subSeqId = self.subSeqId\n\t\ta.altSeqId = self.altSeqId\n\n\t\ta.dir = self.dir\n\t\ta.mod = self.mod\n\t\ta.op = self.op\n\n\t\ta.tid = self.tid\n\t\ta.device = self.device\n\t\ta.stream = self.stream\n\t\ta.grid = self.grid\n\t\ta.block = self.block\n\t\ta.kLongName = self.kLongName\n\n\t\tprint(a.__dict__)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/nvvp.py",
    "content": "import sys\n\nclass NVVP(object):\n\t\"\"\"\n\tThis class gets kernel information from the SQL (nvvp) database.\n\t\"\"\"\n\n\tdriverT = \"CUPTI_ACTIVITY_KIND_DRIVER\"\n\truntimeT = \"CUPTI_ACTIVITY_KIND_RUNTIME\"\n\tkernelT = \"CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL\"\n\tmarkerT = \"CUPTI_ACTIVITY_KIND_MARKER\"\n\tstringT = \"StringTable\"\n\n\tdef __init__(self, db):\n\t\tself.db = db\n\t\tself.markerId = 0\n\n\tdef getProfileStart(self):\n\t\t\"\"\"\n\t\tGet the profile start time\n\t\t\"\"\"\n\t\tprofStart = sys.maxsize\n\t\tfor table in [self.driverT, self.runtimeT, self.kernelT, self.markerT]:\n\t\t\tcolname = \"timestamp\" if table is self.markerT else \"start\"\n\t\t\tcmd = \"select {} from {} ORDER BY {} ASC LIMIT 1\".format(colname, table, colname)\n\t\t\tresult = self.db.select(cmd)\n\t\t\tassert(len(result) <= 1)\n\t\t\tif (len(result) == 1):\n\t\t\t\tassert(colname in result[0])\n\t\t\t\tt = result[0][colname]\n\t\t\t\tif (t < profStart):\n\t\t\t\t\tprofStart = t\n\t\tassert(profStart < sys.maxsize)\n\t\treturn profStart\n\n\tdef getString(self, id_):\n\t\t\"\"\"\n\t\tGet the string associated with an id.\n\t\t\"\"\"\n\t\tcmd = \"select value from {} where _id_ = {}\".format(self.stringT, id_)\n\t\tresult = self.db.select(cmd)\n\t\tassert (len(result) == 1)\n\t\treturn result[0]['value']\n\n\tdef createMarkerTable(self):\n\t\t\"\"\"\n\t\tCreate a temporary table and index it to speed up repeated SQL queries.\n\t\tThe table is an INNER JOIN of CUPTI_ACTIVITY_KIND_MARKER with itself.\n\t\t\"\"\"\n\t\tcmd = 'CREATE TEMPORARY TABLE marker AS SELECT \\\n\t\t\t\t\ta._id_ as id, \\\n\t\t\t\t\ta.timestamp AS startTime, \\\n\t\t\t\t\tb.timestamp AS endTime, \\\n\t\t\t\t\tHEX(a.objectId) AS objectId, \\\n\t\t\t\t\ta.name AS name \\\n\t\t\t\t\tFROM {} AS a INNER JOIN {} AS b ON \\\n\t\t\t\t\ta.id = b.id and \\\n\t\t\t\t\ta.flags = 2 and b.flags = 4'.format(self.markerT, self.markerT)\n\t\tself.db.execute(cmd)\n\n\t\tself.db.execute('CREATE INDEX 
start_index ON marker (startTime)')\n\t\tself.db.execute('CREATE INDEX end_index ON marker (endTime)')\n\t\tself.db.execute('CREATE INDEX id_index ON marker (id)')\n\n\tdef getCPUInfo(self, corrId):\n\t\t\"\"\"\n\t\tGiven the correlation id, get CPU start, end, thread id, process id.\n\t\tThe information can be in the runtime table or the driver table.\n\t\t\"\"\"\n\n\t\t#First look in the runtime table\n\t\tcmd = \"select start,end,processId,threadId from {} where correlationId={}\".format(self.runtimeT, corrId);\n\t\tresult = self.db.select(cmd)\n\t\tassert (len(result) <= 1)\n\n\t\tif (len(result) == 0):\n\t\t\t#Look in the driver table\n\t\t\tcmd = \"select start,end,processId,threadId from {} where correlationId={}\".format(self.driverT, corrId);\n\t\t\tresult = self.db.select(cmd)\n\n\t\tassert (len(result) == 1)\n\t\tinfo = result[0]\n\t\tstart = info['start']\n\t\tend = info['end']\n\t\tpid = info['processId']\n\t\ttid = info['threadId']\n\t\ttid = tid & 0xffffffff\t#convert to unsigned\n\t\tassert (end > start)\n\t\treturn [start, end, pid, tid]\n\n\tdef getKernelInfo(self):\n\t\t\"\"\"\n\t\tGet GPU kernel info\n\t\t\"\"\"\n\t\tcmd = \"select name,correlationId,start,end,deviceId,streamId,gridX,gridY,gridZ,blockX,blockY,blockZ from {}\".format(self.kernelT)\n\t\tresult = self.db.select(cmd)\n\t\treturn result\n\n\tdef getMarkerInfo(self, objId, startTime, endTime):\n\t\t\"\"\"\n\t\tThis function first finds all NVTX markers encapsulating\n\t\ta runtime / driver kernel launch.\n\t\tIt then splits the markers into many lists.\n\t\t\tlayerMarkers : User added NVTX markers\n\t\t\ttraceMarkers : Call trace markers (inserted by pyprof)\n\t\t\treprMarkers  : Markers containing the extra_repr() of a module (inserted by pyprof)\n\t\t\tpyprofMarkers: Markers containing args and kwargs (tensor shape, datatype etc.)\n\t\t\tseqMarkers   : Markers containing PyTorch internal sequence markers (inserted by PyTorch)\n\t\t\taltSeqMarkers: Markers inserted by PyTorch between 
two kernel launches. Needs better explanation.\n\t\t\totherMarkers : Markers not in either of the above categories.\n\n\t\tWe extract seqId from the seq and altSeq markers. The seqId is used in bprop.\n\t\tWe also extract information from the layerMarkers.\n\t\t\"\"\"\n\n\t\tlayerMarkers = []\n\t\ttraceMarkers = []\n\t\treprMarkers = []\n\t\tpyprofMarkers = []\n\t\tseqMarkers = []\n\t\totherMarkers = []\n\t\taltSeqMarkers = []\n\t\tbprop = False\n\n\t\t#Helper functions\n\n\t\tdef delete(objId, sTime):\n\t\t\t\"\"\"\n\t\t\tDelete rows from the temporary SQL table which are no longer required.\n\t\t\tThis speeds up future queries.\n\t\t\t\"\"\"\n\t\t\tmargin = 0\n\t\t\tcmd = 'DELETE FROM marker WHERE objectId = \"{}\" AND endTime < {}'.format(objId, sTime - margin)\n\t\t\t#cmd = 'DELETE FROM marker WHERE endTime < {}'.format(sTime - margin)\n\t\t\tself.db.execute(cmd)\n\n\t\tdef getLayerName(mlist):\n\t\t\t\"\"\"\n\t\t\tGet layer names from layer marker list.\n\t\t\t\"\"\"\n\t\t\tlayers = []\n\t\t\tassert(type(mlist) == list)\n\t\t\tfor m in mlist:\n\t\t\t\tassert(\"layer:\" in m)\n\t\t\t\tl = m.split(\":\")[1]\n\t\t\t\tlayers.append(l)\n\t\t\treturn layers\n\n\t\tdef getSeqId(mlist):\n\t\t\t\"\"\"\n\t\t\tGet sequence ids from seq / alt seq marker list.\n\t\t\t\"\"\"\n\t\t\tids = []\n\t\t\tassert(type(mlist) == list)\n\t\t\tfor m in mlist:\n\t\t\t\tassert(\", seq = \" in m)\n\t\t\t\tseq = int(m.split(\"=\")[1])\n\t\t\t\tids.append(seq)\n\n\t\t\t#Remove duplicates\n\t\t\tids = list(set(ids))\n\t\t\tids.sort()\n\t\t\treturn ids\n\n\t\tdef seqcompare(elem):\n\t\t\t\"\"\"\n\t\t\tSorting function for sequence markers\n\t\t\t\"\"\"\n\t\t\tassert (\", seq = \" in elem)\n\t\t\t#sort by sequence id and then the string\n\t\t\tl = elem.split(\" = \")\n\t\t\treturn l[1] + l[0]\n\n\t\tdef prune(mlist):\n\t\t\t\"\"\"\n\t\t\tRemove markers with the same seqId and if the strings are similar.\n\t\t\tThis function works on a sorted sequence.\n\t\t\t\"\"\"\n\t\t\tassert (type(mlist) 
== list)\n\t\t\tassert (len(mlist))\n\t\t\ta = mlist[0:1]\n\t\t\tfor i in range(1,len(mlist)):\n\t\t\t\tm = mlist[i]\n\t\t\t\tpm = mlist[i-1]\n\t\t\t\tname,seq = m.split(\",\")\n\t\t\t\tpname,pseq = pm.split(\",\")\n\t\t\t\tsimilar = (name in pname) or (pname in name)\n\t\t\t\tif (seq == pseq) and similar:\n\t\t\t\t\tcontinue\n\t\t\t\telse:\n\t\t\t\t\ta.append(m)\n\t\t\treturn a\n\n\t\tdef filterTrace(mlist):\n\t\t\t\"\"\"\n\t\t\tFilter trace markers to remove certain file names.\n\t\t\t\"\"\"\n\t\t\tassert (type(mlist) == list)\n\t\t\tif len(mlist) == 0:\n\t\t\t\treturn mlist\n\t\t\tmlist = mlist[-1]\t#The last stack trace will be a super set.\n\t\t\tmlist = eval(mlist)\n\t\t\tmlist = mlist['traceMarker']\n\t\t\tassert (type(mlist) == list)\n\t\t\tmlist = list(filter(lambda x : \"/torch/nn/modules/\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/nn/functional.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/tensor.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/autograd/__init__.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/_jit_internal.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/pyprof/nvtx/nvmarker.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/apex/optimizers/\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/_utils.py\" not in x, mlist))\n\t\t\tmlist = list(filter(lambda x : \"/torch/optim/\" not in x, mlist))\n\t\t\treturn mlist\n\n\t\t#Find all encapsulating markers\n\t\tcmd = 'SELECT id,name from marker where \\\n\t\t\t\tobjectId = \"{}\" and \\\n\t\t\t\tstartTime < {} and \\\n\t\t\t\tendTime > {} \\\n\t\t\t\tORDER BY startTime ASC'.format(objId, startTime, endTime)\n\t\tresult = self.db.select(cmd)\n\n\t\t#Bin markers into different lists\n\t\tfor r in result:\n\t\t\tm = self.getString(r['name'])\n\n\t\t\t#Hack: If its a known gradient checkpointing marker, ignore it.\n\t\t\tif m.find(\"CheckpointFunctionBackward\") 
>= 0:\n\t\t\t\tcontinue\n\n\t\t\tif (\"_backward, seq =\" in m) or (\"Backward, seq =\" in m) or (\"Backward0, seq =\" in m):\n\t\t\t\tbprop = True\n\n\t\t\tif (\"mod\" in m) and (\"op\" in m) and (\"args\" in m) and (\"type\" in m):\n\t\t\t\tpyprofMarkers.append(m)\n\t\t\telif (\"layer:\" in m):\n\t\t\t\tlayerMarkers.append(m)\n\t\t\telif (\"traceMarker\" in m):\n\t\t\t\ttraceMarkers.append(m)\n\t\t\telif (\"strRepr\" in m):\n\t\t\t\treprMarkers.append(m)\n\t\t\telif (\", seq = \" in m):\n\t\t\t\tseqMarkers.append(m)\n\t\t\telse:\n\t\t\t\totherMarkers.append(m)\n\n\t\t#Remove duplicates, sort and prune seqMarkers\n\t\tif (len(seqMarkers)):\n\t\t\tseqMarkers = list(set(seqMarkers))\n\t\t\tseqMarkers.sort(key=seqcompare)\n\t\t\tseqMarkers = prune(seqMarkers)\n\n\t\t#Remove duplicates from otherMarkers\n\t\totherMarkers = list(set(otherMarkers))\n\n\t\t#Get markers with seq id (inserted by PyTorch) from the previous kernel to the present kernel\n\t\t#Only for fprop kernels\n\t\tif (len(result) and not bprop):\n\t\t\tloId = self.markerId\n\t\t\thiId = result[-1]['id']\n\t\t\tself.markerId = hiId\n\n\t\t\t#Get markers between loId and hiId\n\t\t\tcmd = 'SELECT id,name from marker where objectId = \"{}\" and id > {} and id < {} ORDER BY startTime ASC'.format(objId, loId, hiId)\n\t\t\tresult1 = self.db.select(cmd)\n\n\t\t\tfor r in result1:\n\t\t\t\tm = self.getString(r['name'])\n\t\t\t\t#Keep only markers with a seq id; use the same \", seq = \" format\n\t\t\t\t#that seqcompare() and prune() assert on.\n\t\t\t\tif (\", seq = \" in m):\n\t\t\t\t\taltSeqMarkers.append(m)\n\n\t\t\t#Remove duplicates, sort and prune altSeqMarkers\n\t\t\tif (len(altSeqMarkers)):\n\t\t\t\taltSeqMarkers = list(set(altSeqMarkers))\n\t\t\t\taltSeqMarkers.sort(key=seqcompare)\n\t\t\t\taltSeqMarkers = prune(altSeqMarkers)\n\n\t\tdelete(objId, startTime)\n\n\t\treturn layerMarkers, filterTrace(traceMarkers), reprMarkers, pyprofMarkers, seqMarkers, otherMarkers, altSeqMarkers, getSeqId(seqMarkers), getSeqId(altSeqMarkers), getLayerName(layerMarkers)\n
  },
  {
    "path": "KoSimCSE/apex/pyprof/parse/parse.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nParse the SQL db and print a dictionary for every kernel.\n\"\"\"\n\nimport sys\nimport argparse\nfrom tqdm import tqdm\n\nfrom .db import DB\nfrom .kernel import Kernel\nfrom .nvvp import NVVP\n\ndef parseArgs():\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"Parse SQL (nvvp) db.\")\n\tparser.add_argument(\"file\",\n\t\ttype=str,\n\t\tdefault=None,\n\t\thelp=\"SQL db (nvvp) file.\")\n\n\targs = parser.parse_args()\n\treturn args\n\ndef main():\n\targs = parseArgs()\n\n\tdb = DB(args.file)\n\tnvvp = NVVP(db)\n\n\tkInfo = nvvp.getKernelInfo()\n\tif len(kInfo) == 0:\n\t\tprint(\"Found 0 kernels. Exiting.\", file=sys.stderr)\n\t\tdb.close()\n\t\tsys.exit(0)\n\telse:\n\t\tprint(\"Found {} kernels. Getting info for each kernel.\".format(len(kInfo)), file=sys.stderr)\n\n\tnvvp.createMarkerTable()\n\n\tprevSeqId = -1\n\tprevSubSeqId = -1\n\tprevOp = \"na\"\n\n\tKernel.profStart = nvvp.getProfileStart()\n\n\tfor i in tqdm(range(len(kInfo)), ascii=True):\n\t\tinfo = kInfo[i]\n\t\tk = Kernel()\n\n\t\t#Set kernel info\n\t\tk.setKernelInfo(info)\n\n\t\t#Get, set kernel name\n\t\tname = nvvp.getString(k.kNameId)\n\t\tk.setKernelName(name)\n\n\t\t#Get runtime info\n\t\tinfo = nvvp.getCPUInfo(k.corrId)\n\t\tk.setRunTimeInfo(info)\n\n\t\t#Get and set marker and seqid info\n\t\tinfo = nvvp.getMarkerInfo(k.objId, k.rStartTime, k.rEndTime)\n\t\tk.setMarkerInfo(info)\n\n\t\t#If the seqId contains both 0 and non zero integers, remove 0.\n\t\tif any(seq != 0 for seq in k.seqId) and (0 in k.seqId):\n\t\t\tk.seqId.remove(0)\n\n\t\t#Set direction (it uses seq id)\n\t\tk.setDirection()\n\n\t\t#Set op\n\t\tk.setOp()\n\n\t\t#The following code is based on heuristics.\n\t\t#TODO: Refactor.\n\t\t#Assign subSeqId, adjust seqId and altSeqId\n\t\t#seqId can be 0.\n\t\t#A kernel can have multiple seqIds both in fprop and bprop.\n\t\t#In bprop, seqIds might not decrease monotonically. 
I have observed a few blips.\n\t\tif len(k.seqId):\n\t\t\tassert (k.dir in [\"fprop\", \"bprop\"])\n\t\t\tif (k.dir == \"fprop\"):\n\t\t\t\t#Check if there is a sequence id larger than the previous\n\t\t\t\tinc = (k.seqId[-1] > prevSeqId)\n\t\t\t\tif inc:\n\t\t\t\t\tcurrSeqId = [x for x in k.seqId if x > prevSeqId][0]\n\t\t\t\telse:\n\t\t\t\t\tcurrSeqId = prevSeqId\n\t\t\telse:\n\t\t\t\tcurrSeqId = k.seqId[0]\n\n\t\t\t#if ((currSeqId == prevSeqId) and (k.op == prevOp)):\n\t\t\tif ((currSeqId == prevSeqId) and (k.op == prevOp)) or ((k.op[0] == \"forward\") and (k.op == prevOp) and (k.mod[0] in [\"LSTMCell\", \"GRUCell\", \"RNNCell\"])):\n\t\t\t\t#The second condition is to trap cases when pytorch does not use cudnn for a LSTMCell.\n\t\t\t\tk.subSeqId = prevSubSeqId + 1\n\n\t\t\tprevSeqId = currSeqId\n\t\t\tprevSubSeqId = k.subSeqId\n\t\t\tprevOp = k.op\n\n\t\t\t#Keep currSeqId in k.seqId, move everything else to k.altSeqId.\n\t\t\t#Rebuild the lists instead of calling remove() while iterating,\n\t\t\t#which silently skips the element after each removal.\n\t\t\tk.altSeqId.extend([s for s in k.seqId if s != currSeqId])\n\t\t\tk.seqId = [s for s in k.seqId if s == currSeqId]\n\t\t\tk.altSeqId = [s for s in k.altSeqId if s != currSeqId]\n\n\t\t\tk.altSeqId = list(set(k.altSeqId))\n\t\t\tif (len(k.altSeqId)):\n\t\t\t\t(k.altSeqId).sort()\n\n\t\tk.print()\n\n\tdb.close()\n\nif __name__ == '__main__':\n\tmain()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/__init__.py",
    "content": "from . import data, prof\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/__main__.py",
    "content": "import warnings\n\ntry:\n    from .prof import main\nexcept ImportError as e:\n    warnings.warn(\"Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?\")\n    raise e\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/activation.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Activation(OperatorLayerBase):\n\t\"\"\"\n\tThis class handles the various activation functions.\n\t\"\"\"\n\n\tops = [\"celu\", \"elu\", \"elu_\", \"hardshrink\", \"hardtanh\", \"hardtanh_\", \"leaky_relu\", \"leaky_relu_\", \"logsigmoid\", \"prelu\", \"relu\", \"relu_\", \"relu6\", \"rrelu\", \"rrelu_\", \"selu\", \"sigmoid\", \"softplus\", \"softshrink\", \"softsign\", \"tanh\", \"tanhshrink\", \"threshold\", \"threshold_\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch.nn.functional\", \"torch\", \"Tensor\"])\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) >= 1)\n\t\targ = args[0]\n\t\tassert (arg['type'] == \"tensor\")\n\n\t\tself.i = arg\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.i['shape']),('type', self.i['dtype'])])\n\t\treturn p\n\n\tdef flops(self):\n\t\tdirection = self.dir\n\t\ttensor = self.i['shape']\n\t\tt = self.i['dtype']\n\n\t\t# TODO: revise\n\t\telems = Utility.numElems(tensor)\n\t\treturn elems\n\n\tdef bytes(self):\n\t\tdirection = self.dir\n\t\ttensor = self.i['shape']\n\t\tt = self.i['dtype']\n\n\t\telems = Utility.numElems(tensor)\n\t\telems = elems * (2 if direction == \"fprop\" else 3)\n\n\t\treturn elems * Utility.typeToBytes(t)\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/base.py",
    "content": "from abc import ABC, abstractmethod\n\nclass OperatorLayerBase(ABC):\n\t\"\"\"\n\tBase class for all layers and operators.\n\tEvery derived class should have the following functions.\n\t\"\"\"\n\n\t@abstractmethod\n\tdef tc(self):\n\t\t\"\"\"\n\t\tTensor core usage by the kernel.\n\t\tReturn \"1\" (yes), \"0\" (no, but possible), \"-\" (not applicable)\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef params(self):\n\t\t\"\"\"\n\t\tKernel parameters to be printed.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef flops(self):\n\t\t\"\"\"\n\t\tNote that 1 FMA = 2 flops.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef bytes(self):\n\t\tpass\n\n\t@abstractmethod\n\tdef mod(self):\n\t\t\"\"\"\n\t\tName of the module/class e.g. torch.nn.functional.\n\t\t\"\"\"\n\t\tpass\n\n\t@abstractmethod\n\tdef op(self):\n\t\t\"\"\"\n\t\tName of the operator e.g. sigmoid.\n\t\t\"\"\"\n\t\tpass\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/blas.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\nimport numpy as np\n\nTC_GEMMS = [\"884gemm\", \"1688gemm\"]\n\nclass Addmm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\",])\n\t\tassert (op in [\"addmm\", \"addmm_\",])\n\n\t\t#Get alpha and beta\n\t\talpha = 1\n\t\tbeta = 1\n\t\tif any(x['name'] == 'alpha' for x in args):\n\t\t\talpha = list(filter(lambda x : x['name'] == \"alpha\", args))[0]\n\t\t\talpha = alpha['value']\n\n\t\tif any(x['name'] == 'beta' for x in args):\n\t\t\tbeta = list(filter(lambda x : x['name'] == \"beta\", args))[0]\n\t\t\tbeta = beta['value']\n\n\t\tself.alpha = alpha\n\t\tself.beta = beta\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) == 3)\n\t\tC,A,B = args\n\t\tm,k1 = A['shape']\n\t\tk2,n = B['shape']\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tt3 = C['dtype']\n\t\tassert(t1 == t2 == t3)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.C = C\n\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\t\treturn\n\n\tdef tc(self):\n\t\tfor s in TC_GEMMS:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tm, n, k = self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\treturn self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef params(self):\n\t\tp = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\nclass Bmm(OperatorLayerBase):\n\n\tdef __init__(self, 
d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\") and (op == \"bmm\")\n\n\t\t#Filter out named params (kwargs)\n\t\targs = list(filter(lambda x : x['name'] == \"\", args))\n\n\t\tassert (len(args) == 2)\n\t\tA,B = args\n\t\tb1,m,k1 = A['shape']\n\t\tb2,k2,n = B['shape']\n\t\tassert (b1 == b2)\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.b = b1\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\tdef tc(self):\n\t\tfor s in TC_GEMMS:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn 0\n\n\tdef params(self):\n\t\t#p = OrderedDict([('A', A['shape']), ('B', B['shape']), ('type', t1)])\n\t\tp = OrderedDict([('B',self.b), ('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn self.b * self.m * self.n * self.k * 2\n\n\tdef bytes(self):\n\t\tb, m, n, k = self.b, self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * b * (m*n + m*k + n*k)\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Matmul(OperatorLayerBase):\n\n\tNON_GEMM = [\"kernelPointwiseApply2\", \"reduce_1Block_kernel\", \"elementwise_kernel\"]\n\tNON_TC = NON_GEMM + [\"dot_kernel\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.name = d.name\n\t\tself.sub = d.sub\n\n\t\tassert ((mod == \"torch\") and (op == \"matmul\")) or ((mod == \"Tensor\") and (op == \"__matmul__\"))\n\t\tassert (len(args) == 2)\n\n\t\tassert any([x in d.name for x in Matmul.NON_TC + 
[\"gemm\", \"gemv\"]])\n\n\t\tA,B = args\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tA = A['shape']\n\t\tB = B['shape']\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.type = t1\n\n\t\t# batch, MNK\n\t\tif (len(A) == 1) and (len(B) == 1):\n\t\t\t#dot product\n\t\t\tassert (A[0] == B[0])\n\t\t\tself.b = (1,)\n\t\t\tself.m = 1\n\t\t\tself.n = 1\n\t\t\tself.k = A[0]\n\n\t\telif (len(A) == 2) and (len(B) == 2):\n\t\t\t#gemm\n\t\t\tm,k1 = A\n\t\t\tk2,n = B\n\t\t\tassert(k1 == k2)\n\t\t\tself.b = (1,)\n\t\t\tself.m = m\n\t\t\tself.n = n\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 1) and (len(B) == 2):\n\t\t\t#vector matrix\n\t\t\tk1 = A[0]\n\t\t\tk2,n = B\n\t\t\tassert(k1 == k2)\n\n\t\t\tself.b = (1,)\n\t\t\tself.m = 1\n\t\t\tself.n = n\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 2) and (len(B) == 1):\n\t\t\t#gemv\n\t\t\tm,k1 = A\n\t\t\tk2 = B[0]\n\t\t\tassert (k1 == k2)\n\n\t\t\tself.b = (1,)\n\t\t\tself.m = m\n\t\t\tself.n = 1\n\t\t\tself.k = k1\n\n\t\telif (len(A) == 1) and (len(B) > 2):\n\t\t\tassert (A[0] == B[-2])\n\n\t\t\tself.b = B[0:-2]\n\t\t\tself.m = 1\n\t\t\tself.n = B[-1]\n\t\t\tself.k = B[-2]\n\n\t\telif (len(B) == 1) and (len(A) > 2):\n\t\t\tassert (B[0] == A[-1])\n\n\t\t\tself.b = A[0:-2]\n\t\t\tself.m = A[-2]\n\t\t\tself.n = 1\n\t\t\tself.k = A[-1]\n\n\t\telse:\n\t\t\tassert (len(A) >= 2)\n\t\t\tassert (len(B) >= 2)\n\t\t\tassert (A[-1] == B[-2])\n\t\t\tself.m = A[-2]\n\t\t\tself.n = B[-1]\n\t\t\tself.k = A[-1]\n\n\t\t\taa = np.empty(A[0:-2])\n\t\t\tbb = np.empty(B[0:-2])\n\t\t\tself.b = np.broadcast(aa, bb).shape\n\n\tdef params(self):\n\t\treturn OrderedDict([('A', self.A), ('B', self.B), ('type', self.type)])\n\n\tdef tc(self):\n\t\tif self.name in Matmul.NON_TC:\n\t\t\treturn \"-\"\n\t\telse:\n\t\t\tfor s in TC_GEMMS:\n\t\t\t\tif s in self.name:\n\t\t\t\t\treturn 1\n\t\t\treturn 0\n\n\tdef bytes(self):\n\t\t# TODO: check bytes for non-GEMM cases\n\t\tif self.name in 
Matmul.NON_GEMM:\n\t\t\treturn 2 * Utility.typeToBytes(self.type) * Utility.numElems(self.A) #could be B as well\n\t\telse:\n\t\t\tm, n, k = self.m, self.n, self.k\n\t\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\t# TODO: calculate actual FLOPs. At least we're not saying it's GEMM FLOPs for now.\n\t\tif self.name in Matmul.NON_GEMM:\n\t\t\treturn 0\n\t\telse:\n\t\t\treturn Utility.numElems(self.b) * self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Mm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\") and (op == \"mm\")\n\t\tassert (len(args) == 2)\n\n\t\tA,B = args\n\t\tm,k1 = A['shape']\n\t\tk2,n = B['shape']\n\t\tassert (k1 == k2)\n\t\tt1 = A['dtype']\n\t\tt2 = B['dtype']\n\t\tassert(t1 == t2)\n\n\t\tself.A = A\n\t\tself.B = B\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k1\n\t\tself.type = t1\n\t\tself.name = d.name\n\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\tfor s in TC_GEMMS:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tm, n, k = self.m, self.n, self.k\n\t\treturn Utility.typeToBytes(self.type) * (m*n + m*k + n*k)\n\n\tdef flops(self):\n\t\treturn self.m * self.n * self.k * 2\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/conv.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Conv(OperatorLayerBase):\n\n\t\"\"\"\n\t# N = batch size\n\t# C,H,W = input channels, height, width\n\t# K,P,Q = output channels, height, width\n\t# R,S = filter height, width\n\t# g = groups\n\t\"\"\"\n\n\t#todo: refine winograd and FFT\n\tconvAuxList = [\"nchwToNhwc\", \"nhwcToNchw\", \"OffsetsKernel\",]\n\twinoAuxList = [\"generateWinogradTilesKernel\", \"winogradWgradData\", \"winogradWgradOutput\", \"winogradWgradDelta\"]\n\tfftAuxList = [\"compute_gemm_pointers\", \"flip_filter\", \"fft2d_r2c_\", \"fft2d_c2r_\", \"fft1d_r2c\", \"fft1d_c2r\"]\n\tmiscAuxList = [\"scaleTensor_kernel\",]\n\n\tconvList = [\"_s884cudnn_\", \"_s1688cudnn_\", \"_scudnn_\", \"2d_grouped_direct_kernel\", \"cudnn::detail::implicit_convolve_sgemm\", \"cudnn::detail::dgrad2d_alg1_1\", \"cudnn::detail::wgrad_alg0_engine\", \"cudnn::detail::dgrad_engine\", \"dgrad_1x1_stride_2x2\", \"spatialDepthwiseConvolutionUpdateOutput\"]\n\twinoList = [\"winograd3x3Kernel\", \"_sgemm_\"]\n\tfftList = [\"fermiPlusCgemmLDS128_batched\", \"_gcgemm_\",]\n\tmiscList = []\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.dir = d.dir\n\t\tself.name = d.name\n\t\tself.sub = d.sub\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op in [\"conv1d\", \"conv2d\"])\n\t\tlength = len(args)\n\t\tassert (length >= 2) and (length <= 7)\n\t\ti,w = args[0], args[1]\n\t\tassert (i['type'] == \"tensor\")\n\t\tassert (w['type'] == \"tensor\")\n\n\t\t#ignore bias\n\n\t\tif (length >= 4) and (args[3]['name'] == \"\"):\n\t\t\ts = args[3]\n\t\telif any(x['name'] == 'stride' for x in args):\n\t\t\ts = list(filter(lambda x : x['name'] == 'stride', args))[0]\n\t\telse:\n\t\t\ts = {'name': 'stride', 'type': 
'int', 'value': 1}\n\n\t\tif (length >= 5) and (args[4]['name'] == \"\"):\n\t\t\tp = args[4]\n\t\telif any(x['name'] == 'padding' for x in args):\n\t\t\tp = list(filter(lambda x : x['name'] == 'padding', args))[0]\n\t\telse:\n\t\t\tp = {'name': 'padding', 'type': 'int', 'value': 0}\n\n\t\tif (length >= 6) and (args[5]['name'] == \"\"):\n\t\t\td = args[5]\n\t\telif any(x['name'] == 'dilation' for x in args):\n\t\t\td = list(filter(lambda x : x['name'] == 'dilation', args))[0]\n\t\telse:\n\t\t\td = {'name': 'dilation', 'type': 'int', 'value': 1}\n\n\t\tif (length == 7) and (args[6]['name'] == \"\"):\n\t\t\tg = args[6]\n\t\telif any(x['name'] == 'groups' for x in args):\n\t\t\tg = list(filter(lambda x : x['name'] == 'groups', args))[0]\n\t\telse:\n\t\t\tg = {'name': 'groups', 'type': 'int', 'value': 1}\n\n\t\tif op == \"conv1d\":\n\t\t\tassert (len(i['shape']) == 3)\n\t\t\tassert (len(w['shape']) == 3)\n\t\t\tassert (i['dtype'] == w['dtype'])\n\t\t\tN, C1, W = i['shape']\n\t\t\tK, C2, S = w['shape']\n\t\t\tassert (C1 == C2)\n\t\t\tp = p['value'] if Utility.isscalar(p['type']) else p['value'][0]\n\t\t\ts = s['value'] if Utility.isscalar(s['type']) else s['value'][0]\n\t\t\td = d['value'] if Utility.isscalar(d['type']) else d['value'][0]\n\t\t\tg = g['value']\n\t\t\tassert (g == 1)\n\t\t\tH = 1\n\t\t\tR = 1\n\n\t\t\tP = 1 + (H - (((R-1))+1))\n\t\t\tQ = 1 + (W + 2*p - (((S-1)*d)+1))/s\n\t\t\tP = int(P)\n\t\t\tQ = int(Q)\n\t\t\tif (H == 1):\n\t\t\t\tassert (P == 1)\n\t\t\tif (W == 1):\n\t\t\t\tassert (Q == 1)\n\n\t\t\tself.N = N\n\t\t\tself.C = C1\n\t\t\tself.H = H\n\t\t\tself.W = W\n\t\t\tself.K = K\n\t\t\tself.P = P\n\t\t\tself.Q = Q\n\t\t\tself.R = R\n\t\t\tself.S = S\n\t\t\tself.ph = 0\n\t\t\tself.pw = p\n\t\t\tself.U = 1\n\t\t\tself.V = s\n\t\t\tself.dh = 1\n\t\t\tself.dw = d\n\t\t\tself.g = g\n\t\t\tself.type = i['dtype']\n\n\t\telif op == \"conv2d\":\n\t\t\tassert (len(i['shape']) == 4)\n\t\t\tassert (len(w['shape']) == 4)\n\t\t\tassert (i['dtype'] == 
w['dtype'])\n\t\t\tN, C1, H, W = i['shape']\n\t\t\tK, C2, R, S = w['shape']\n\n\t\t\tif Utility.isscalar(p['type']):\n\t\t\t\tph = pw = p['value']\n\t\t\telse:\n\t\t\t\tassert (p['type'] == \"tuple\")\n\t\t\t\tph, pw = p['value']\n\n\t\t\tif Utility.isscalar(s['type']):\n\t\t\t\tsh = sw = s['value']\n\t\t\telse:\n\t\t\t\tassert (s['type'] == \"tuple\")\n\t\t\t\tsh, sw = s['value']\n\n\t\t\tif Utility.isscalar(d['type']):\n\t\t\t\tdh = dw = d['value']\n\t\t\telse:\n\t\t\t\tassert (d['type'] == \"tuple\")\n\t\t\t\tdh, dw = d['value']\n\n\t\t\tg = g['value']\n\t\t\tassert (g >= 1)\n\t\t\tassert (C1 == C2*g)\n\n\t\t\tP = 1 + (H + 2*ph - (((R-1)*dh)+1))/sh\n\t\t\tQ = 1 + (W + 2*pw - (((S-1)*dw)+1))/sw\n\t\t\tP = int(P)\n\t\t\tQ = int(Q)\n\t\t\tif (H == 1):\n\t\t\t\tassert (P == 1)\n\t\t\tif (W == 1):\n\t\t\t\tassert (Q == 1)\n\n\t\t\tself.N = N\n\t\t\tself.C = C1\n\t\t\tself.H = H\n\t\t\tself.W = W\n\t\t\tself.K = K\n\t\t\tself.P = P\n\t\t\tself.Q = Q\n\t\t\tself.R = R\n\t\t\tself.S = S\n\t\t\tself.ph = ph\n\t\t\tself.pw = pw\n\t\t\tself.U = sh\n\t\t\tself.V = sw\n\t\t\tself.dh = dh\n\t\t\tself.dw = dw\n\t\t\tself.g = g\n\t\t\tself.type = i['dtype']\n\n\t\telse:\n\t\t\tassert False\n\n\tdef params(self):\n\t\tp = OrderedDict([('N',self.N), ('C',self.C), ('H',self.H), ('W',self.W), ('K',self.K), ('P',self.P), ('Q',self.Q), ('R',self.R), ('S',self.S), ('ph',self.ph), ('pw',self.pw), ('U',self.U), ('V',self.V), ('dh',self.dh), ('dw',self.dw), ('g',self.g), ('type',self.type)])\n\t\treturn p\n\n\tdef conv_bytes_flops(self, N, C, H, W, K, P, Q, R, S, g, t):\n\t\tf = 2*N*K*P*Q*C*R*S/g #for fprop\n\t\telems = N*C*H*W + K*C*R*S/g + N*K*P*Q\n\t\tb = elems * Utility.typeToBytes(t)\n\t\treturn b,f\n\n\tdef bytes_flops(self):\n\t\tN,C,H,W,K,P,Q,R,S,ph,pw,U,V,dh,dw,g,t = self.params().values()\n\n\t\tif any(x in self.name for x in Conv.convAuxList+Conv.winoAuxList+Conv.fftAuxList+Conv.miscAuxList):\n\t\t\tbytes, flops = [0, 0]\n\n\t\telif any(x in self.name for x in 
Conv.convList+Conv.winoList+Conv.fftList+Conv.miscList):\n\t\t\tif g == 1:\n\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\telse:\n\t\t\t\tif \"2d_grouped_direct_kernel\" in self.name:\t#only 1 kernel is called\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\t\telif \"spatialDepthwiseConvolutionUpdateOutput\" in self.name: #one kernel for separable conv\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t)\n\t\t\t\telse:\t#a kernel per group is called\n\t\t\t\t\tbytes, flops = self.conv_bytes_flops(N,C/g,H,W,K/g,P,Q,R,S,1,t)\n\n\t\telif (\"calc_bias_diff\" in self.name):\t#bias gradient\n\t\t\telems = N*K*P*Q\n\t\t\tflops = elems\n\t\t\tbytes = 2 * elems * Utility.typeToBytes(t)\n\t\t\t#params = OrderedDict([('N',N), ('K',K), ('P',P), ('Q',Q), ('type', t)])\n\n\t\telse:\n\t\t\tbytes, flops = [0, 0]\n\n\t\treturn bytes, flops\n\n\tdef bytes(self):\n\t\tb,_ = self.bytes_flops()\n\t\treturn b\n\n\tdef flops(self):\n\t\t_,f = self.bytes_flops()\n\t\treturn f\n\n\tdef tc(self):\n\t\tfor s in [\"884cudnn\", \"1688cudnn\"]:\n\t\t\tif s in self.name:\n\t\t\t\treturn 1\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/convert.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Convert(OperatorLayerBase):\n\t\"\"\"\n\tClass to handle convert operations.\n\t\"\"\"\n\tops = [\"byte\", \"char\", \"double\", \"float\", \"half\", \"int\", \"long\", \"short\", \"to\"]\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op in Convert.ops)\n\t\tassert (len(args) == 1)\n\n\t\t#The argument could be a tensor or scalar\n\t\tt = args[0]\n\t\tif t['type'] == \"tensor\":\n\t\t\tshape = t['shape']\n\t\t\tstype = t['dtype']\n\t\telse:\n\t\t\tshape = (1,)\n\t\t\tstype = t['type']\n\t\tif self.op_ == \"to\":\n\t\t\top = stype\n\n\t\tself.shape = shape\n\t\tself.stype = stype\n\t\tself.dtype = op\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * (Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype))\n\t\treturn b\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/data.py",
    "content": "from .utility import Utility\n\nclass Data(object):\n\t\"\"\"\n\tClass to store all the data for every kernel e.g. name, bytes, flops, device, stream etc.\n\t\"\"\"\n\tdef __init__(self, kernel):\n\t\t#Available from NVprof\n\t\tself.tid = kernel['tid']\n\t\tself.device = kernel['device']\n\t\tself.stream = kernel['stream']\n\t\tself.grid = str(kernel['grid']).replace(\" \",\"\").replace(\"(\",\"\").replace(\")\",\"\")\n\t\tself.block = str(kernel['block']).replace(\" \",\"\").replace(\"(\",\"\").replace(\")\",\"\")\n\t\tself.name = kernel['kShortName'].replace(\" \",\"_\")\n\t\tself.lName = kernel['kLongName']\n\t\tself.sil = kernel['kDuration']\t#units ns\n\n\t\tself.index = None\n\n\t\t#Markers\n\t\tself.argMarker = kernel['marker']\n\t\tself.modMarker = kernel['reprMarkers']\n\t\tself.seqMarker = kernel['seqMarker']\n\n\t\tself.layer = kernel['layer']\n\t\tself.trace = kernel['trace']\n\n\t\tself.seqId = kernel['seqId']\n\t\tself.altSeqId = kernel['altSeqId']\n\n\t\tself.dir = kernel['dir']\n\t\tself.sub = kernel['subSeqId']\n\n\t\tself.mod = \"na\"\n\t\tself.op = \"na\"\n\t\tself.params = {\"na\":\"na\"}\n\t\tself.tc = \"na\"\n\t\tself.flops = 0\n\t\tself.bytes = 0\n\n\tdef setParams(self, params):\n\t\t#Remove space from params\n\t\tqaz = \"\"\n\t\tfor key,value in params.items():\n\t\t\tif \"type\" not in key:\n\t\t\t\tqaz += \"{}={},\".format(key,value)\n\t\t\telse:\n\t\t\t\tif type(value) is str:\n\t\t\t\t\tqaz += \"{},\".format(Utility.typeToString(value))\n\t\t\t\telse:\n\t\t\t\t\tqaz += \"{}\".format(value)\n\n\t\tself.params = qaz.replace(\" \", \"\")\n\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/dropout.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Dropout(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"dropout\")\n\t\t#assert (len(args) == 1)\n\n\t\tself.shape = args[0]['shape']\n\t\tself.type  = args[0]['dtype']\n\t\tself.dir = d.dir\n\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\t#Ignoring the cost of writing and reading the mask\n\t\treturn Utility.typeToBytes(self.type) * self.elems() * 2\n\n\tdef flops(self):\n\t\t# Note: This is approximate and depends on the RNG\n\t\treturn 5*self.elems()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/embedding.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Embedding(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"embedding\")\n\n\t\tself.ishape = args[0]['shape']\n\t\tself.itype = args[0]['dtype']\n\n\t\tself.eshape = args[1]['shape']\n\t\tself.etype = args[1]['dtype']\n\n\t\tassert (len(self.eshape) == 2)\n\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('I', self.ishape), ('itype', self.itype), ('E', self.eshape), ('etype', self.etype)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef bytes(self):\n\t\tishape = self.ishape\n\t\titype = self.itype\n\t\teshape = self.eshape\n\t\tetype = self.etype\n\n\t\tielems = Utility.numElems(ishape)\n\n\t\tb = 0\n\t\tif self.dir == \"fprop\":\n\t\t\t#indices\n\t\t\tb += ielems * Utility.typeToBytes(itype)\n\t\t\t#read and write the embedding matrix\n\t\t\tb += ielems * eshape[1] * 2 * Utility.typeToBytes(etype)\n\t\telse:\n\t\t\t#3 times the size of the incoming gradient\n\t\t\tb = ielems * eshape[1] * 3 * Utility.typeToBytes(etype)\n\n\t\t\tif self.sub > 0:\n\t\t\t\tb = 0\n\n\t\treturn b\n\n\tdef flops(self):\n\t\t# Note: not implemented yet\n\t\treturn 0\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/index_slice_join_mutate.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nimport numpy as np\nfrom .base import OperatorLayerBase\n\nclass Cat(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\")\n\t\tassert (op == \"cat\")\n\t\tassert (len(args) >= 2)\n\n\t\tt = args[0]['dtype']\n\t\tshapes = []\n\n\t\tfor arg in args:\n\t\t\tif arg['type'] == \"tensor\":\n\t\t\t\tassert (arg['dtype'] == t)\n\t\t\t\tshapes.append(arg['shape'])\n\n\t\tself.type = t\n\t\tself.shapes = shapes\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shapes), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\tb = 0\n\t\tfor s in self.shapes:\n\t\t\tb += Utility.numElems(s)\n\t\treturn 2 * b * Utility.typeToBytes(self.type)\n\nclass Reshape(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"reshape\")\n\n\t\t#Temporarily commenting three lines\n\t\t#assert (len(args) == 2)\n\t\t#t,s = args\n\t\t#assert s['type'] == \"tuple\"\n\n\t\tt = args[0]\n\t\tassert t['type'] == \"tensor\"\n\t\tself.type = t['dtype']\n\t\tself.shape = t['shape']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn 0\n\nclass 
Gather(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"gather\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 3)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\targ = args[0]\n\t\telse:\n\t\t\targ = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tassert (arg['type'] == \"tensor\")\n\n\t\tself.shape = arg['shape']\n\t\tself.type = arg['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n\nclass MaskedScatter(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"masked_scatter_\")\n\t\tassert (len(args) == 3)\n\n\t\tdst, mask, src = args\n\t\tassert (dst['type'] == mask['type'] == src['type'] == \"tensor\")\n\t\tassert (mask['dtype'] == \"uint8\")\n\t\tassert (dst['dtype'] == src['dtype'])\n\t\tassert (dst['shape'] == mask['shape'])\n\n\t\tself.shape = dst['shape']\n\t\tself.type = dst['dtype']\n\t\tself.seqId = d.seqId\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef 
mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\telems = Utility.numElems(self.shape)\n\n\t\t#src and dst\n\t\tb = 2 * elems * Utility.typeToBytes(self.type)\n\n\t\t#mask (uint8)\n\t\tb += elems\n\n\t\tif (self.seqId > 0):\n\t\t\tb = 0\n\t\treturn b\n\nclass Nonzero(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"nonzero\")\n\t\tassert (len(args) == 1)\n\n\t\targ = args[0]\n\t\tself.shape = arg['shape']\n\t\tself.type = arg['dtype']\n\t\tself.seqId = d.seqId\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\telems = Utility.numElems(self.shape)\n\t\tdim = len(self.shape)\n\n\t\t#input tensor\n\t\tb = elems * Utility.typeToBytes(self.type)\n\n\t\t#in the worst case, the output is a (elems x dim) tensor of type \"long\"\n\t\tb += elems * dim * Utility.typeToBytes(\"int64\")\n\n\t\tif self.seqId > 0:\n\t\t\treturn 0\n\t\telse:\n\t\t\treturn b\n\nclass IndexSelect(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"index_select\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 3)\n\n\t\t#Get input, dim and index\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tt = args[0]\n\t\telse:\n\t\t\tt = list(filter(lambda x : x['name'] == 
\"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\td = args[1]\n\t\telse:\n\t\t\td = list(filter(lambda x : x['name'] == \"dim\", args))[0]\n\n\t\tif (args[2]['name'] == \"\"):\n\t\t\ti = args[2]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"index\", args))[0]\n\n\t\tassert (t['type'] == i['type'] == \"tensor\")\n\t\tassert (d['type'] == \"int\")\n\t\tassert (i['dtype'] == \"int64\")\n\t\tassert (len(i['shape']) == 1)\n\n\t\tshape = t['shape']\n\t\tdim = d['value']\n\t\tindices = i['shape'][0]\n\t\tassert (dim < len(shape))\n\n\t\tself.shape = shape\n\t\tself.dim = dim\n\t\tself.indices = indices\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape),('D', self.dim),('I', self.indices),('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\t#determine the shape of the output tensor\n\t\tshape = list(self.shape)\n\t\tshape[self.dim] = self.indices\n\n\t\tb = 0\n\n\t\t#time to read the input and write the output\n\t\telems = Utility.numElems(shape)\n\t\tb += 2 * elems * Utility.typeToBytes(self.type)\n\n\t\t#time to read the indices\n\t\tb += self.indices * Utility.typeToBytes(\"int64\")\n\n\t\treturn b\n\nclass MaskedSelect(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\t\tself.sub = d.sub\n\n\t\tassert (mod == \"Tensor\") or (mod == \"torch\")\n\t\tassert (op == \"masked_select\")\n\n\t\t#Filter out the \"out\" parameter\n\t\targs = list(filter(lambda x : x['name'] != 'out', args))\n\t\tassert (len(args) == 2)\n\n\t\t#Get input and mask\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tt = args[0]\n\t\telse:\n\t\t\tt = list(filter(lambda x : x['name'] 
== \"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\tm = args[1]\n\t\telse:\n\t\t\tm = list(filter(lambda x : x['name'] == \"mask\", args))[0]\n\n\t\tassert (m['dtype'] == \"uint8\")\n\n\t\ttensor = t['shape']\n\t\tmask = m['shape']\n\n\t\t#check for broadcast condition\n\t\tif (tensor != mask):\n\t\t\tarray1 = np.empty(list(tensor))\n\t\t\tarray2 = np.empty(list(mask))\n\t\t\ttry:\n\t\t\t\tout = np.broadcast(array1, array2).shape\n\t\t\texcept:\n\t\t\t\tassert False\n\n\t\tself.tshape = tensor\n\t\tself.mshape = mask\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.tshape),('M', self.mshape),('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\ttensor = self.tshape\n\t\tmask = self.mshape\n\t\tt = self.type\n\n\t\t#in the worst case, #output elements = #input elements\n\t\tb = 2 * Utility.numElems(tensor) * Utility.typeToBytes(t)\n\n\t\t#mask tensor (assuming uint8)\n\t\tb += Utility.numElems(mask)\n\t\treturn b\n\n\tdef flops(self):\n\t\treturn 0\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/linear.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Linear(OperatorLayerBase):\n\n\t'''\n\tNotes:\n\tIf the bias occurs before the GEMM, then its 1 write (bias expansion).\n\tIf the bias occurs after, then its 1 read and 1 write.\n\tbias in bprop is a reduction and hence is 1 read.\n\t'''\n\n\tgemmKernels = [\"gemm\", \"gemv\", \"dot_kernel\", \"splitKreduce_kernel\", \"reduce_1Block_kernel\"]\n\tbiasKernels = [\"kernelReduceContigDim\", \"kernelReduceNoncontigDim_shared\", \"elementwise_kernel\", \"reduce_kernel\"]\n\n\tdef setXWBMNK(self, args):\n\t\tx = None\n\t\tw = None\n\t\tb = None\n\t\tif (len(args) == 2):\n\t\t\tx,w = args\n\t\telif (len(args) == 3):\n\t\t\tx,w,b = args\n\t\t\tassert (x['type'] == w['type'] == \"tensor\")\n\t\t\tif (b['type'] == \"tensor\"):\n\t\t\t\tassert(len(b['shape']) == 1)\n\t\t\telif (b['type'] == \"NoneType\"):\n\t\t\t\tassert b['value'] is None\n\t\t\t\tb = None\n\t\t\telse:\n\t\t\t\tassert False\n\t\telse:\n\t\t\tassert False\n\n\t\tassert(len(w['shape']) == 2)\n\t\tk1 = x['shape'][-1]\n\t\tn,k2 = w['shape']\n\t\tassert(k1 == k2)\n\t\tif b is not None:\n\t\t\tassert(b['shape'][0] == n)\n\t\tt1 = x['dtype']\n\t\tt2 = w['dtype']\n\t\tassert(t1 == t2)\n\n\t\t# X, W, B\n\t\tself.x = x['shape']\n\t\tself.w = w['shape']\n\t\tself.b = b['shape'] if b is not None else None\n\t\tself.type = t1\n\n\t\t# M, N, K\n\t\t#n = Utility.numElems(x[0:-1])\n\t\tn = self.x[0:-1]\n\t\tk = self.x[-1]\n\t\tm,k1 = self.w\n\t\tassert (k == k1)\n\n\t\tself.m = m\n\t\tself.n = n\n\t\tself.k = k\n\n\tdef tc(self):\n\t\tif self.op() == \"linear\":\n\t\t\treturn 1 if \"884gemm\" in self.name else 0\n\t\telse:\n\t\t\treturn \"-\"\n\n\tdef __init__(self, d):\n\t\tself.name = d.name\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tassert (mod == 
\"torch.nn.functional\")\n\t\tassert (op == \"linear\")\n\n\t\tself.setXWBMNK(args)\n\n\t\tif any(x in d.name for x in Linear.gemmKernels):\n\t\t\tself.op_ = \"linear\"\n\t\telse:\n\t\t\tassert (d.name in Linear.biasKernels)\n\t\t\tself.op_ = \"bias\"\n\n\t\t'''\n\t\telif ((\"kernelPointwiseApply2\" in d.name) or (\"kernelReduceContigDim\" in d.name) or (\"kernelReduceNoncontigDim_shared\" in d.name)):\n\t\t\t#bias expansion was before the gemm\n\t\t\tself.op_ = \"bias\"\n\n\t\telif (\"elementwise_kernel\" in d.name):\n\t\t\t#Bias addition happens later with a broadcast tensor\n\t\t\tself.op_ = \"bias\"\n\t\t\tassert (len(d.argMarker) == 2)\n\t\t\tmarker = eval(d.argMarker[1])\n\t\t\tmod = marker['mod']\n\t\t\top = marker['op']\n\t\t\targs = marker['args']\n\n\t\t\tassert (mod == \"Tensor\")\n\t\t\tassert (op == \"__iadd__\")\n\t\t\tassert (len(args) == 2)\n\t\t\tmn = args[0]['shape']\n\t\t\tb = args[1]['shape']\n\t\t\tassert (len(b) == 1)\n\n\t\t\tassert (mn == (self.n + (self.m,)))\n\t\t\tassert (b == self.b)\n\n\t\telse:\n\t\t\tassert False\n\t\t'''\n\n\tdef params(self):\n\t\t#p = OrderedDict([('X', self.x), ('W', self.w), ('B', self.b), ('type', self.type)])\n\n\t\tm, n, k, x, w, t = self.m, self.n, self.k, self.x, self.w, self.type\n\t\tif len(n) == 1:\n\t\t\tn = n[0]\n\n\t\tif self.op_ == \"linear\":\n\t\t\tif self.dir == \"fprop\":\n\t\t\t\tp = OrderedDict([('M', m), ('N', n), ('K', k), ('type', t)])\n\t\t\telif self.dir == \"bprop\":\n\t\t\t\tif self.sub == 0:\t\t#dgrad (most likely)\n\t\t\t\t\tp = OrderedDict([('M', k), ('N', n), ('K', m), ('type', t)])\n\t\t\t\telif self.sub == 1:\t#wgrad (most likely)\n\t\t\t\t\tp = OrderedDict([('M', k), ('N', m), ('K', n), ('type', t)])\n\t\t\t\telse:\n\t\t\t\t\t#This happens when there are additional kernels for reduction\n\t\t\t\t\tp = OrderedDict([('X', x), ('W', w), ('type', t)])\n\t\t\telse:\n\t\t\t\tassert False\n\n\t\telif self.op_ == \"bias\":\n\t\t\tp = OrderedDict([('M', m), ('N', n), ('type', 
t)])\n\t\telse:\n\t\t\tassert False\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef bytesFlops(self):\n\n\t\tm = self.m\n\t\tn = Utility.numElems(self.n)\n\t\tk = self.k\n\n\t\tif self.op_ == \"linear\":\n\t\t\tif self.dir == \"fprop\":\n\t\t\t\tf = m * n * k * 2\n\t\t\t\t#read X (n,k) and W (m,k), write Y (n,m)\n\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\telif self.dir == \"bprop\":\n\t\t\t\tif self.sub == 0:\t\t#dgrad (most likely)\n\t\t\t\t\tf = m * n * k * 2\n\t\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\t\telif self.sub == 1:\t#wgrad (most likely)\n\t\t\t\t\tf = m * n * k * 2\n\t\t\t\t\tb = (m*n + m*k + n*k) * Utility.typeToBytes(self.type)\n\t\t\t\telse:\n\t\t\t\t\t#This happens when there are additional kernels for reduction\n\t\t\t\t\tf = 0\n\t\t\t\t\tb = 0\n\t\t\telse:\n\t\t\t\tassert False\n\n\t\telif self.op_ == \"bias\":\n\t\t\tf = m * n\n\t\t\tb = 2 * m * n * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\tassert False\n\t\treturn b,f\n\n\tdef bytes(self):\n\t\tb, f = self.bytesFlops()\n\t\treturn b\n\n\tdef flops(self):\n\t\tb, f = self.bytesFlops()\n\t\treturn f\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/loss.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\n#TODO: Add support for additional loss functions.\n\nclass MSELoss(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"mse_loss\")\n\t\tassert (len(args) == 3)\n\n\t\t#Get input, target and reduction\n\t\tif (args[0]['name'] == \"\"):\n\t\t\tx = args[0]\n\t\telse:\n\t\t\tx = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tif (args[1]['name'] == \"\"):\n\t\t\ty = args[1]\n\t\telse:\n\t\t\ty = list(filter(lambda x : x['name'] == \"target\", args))[0]\n\n\t\tif (args[2]['name'] == \"\"):\n\t\t\tr = args[2]\n\t\telse:\n\t\t\tr = list(filter(lambda x : x['name'] == \"reduction\", args))[0]\n\n\t\tassert (x['type'] == y['type'] == \"tensor\")\n\t\tassert (x['shape'] == y['shape'])\n\t\tassert (x['dtype'] == y['dtype'])\n\t\tassert (r['type'] == \"str\")\n\t\tassert (r['value'] in [\"none\", \"mean\", \"sum\"])\n\n\t\tself.shape = x['shape']\n\t\tself.type = x['dtype']\n\t\tself.red = r['value']\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type), ('red', self.red)])\n\t\treturn p\n\n\tdef elems(self):\n\t\tred = self.red\n\t\te = Utility.numElems(self.shape)\n\n\t\tif self.dir == \"fprop\":\n\t\t\tif red == \"none\":\n\t\t\t\te *= 3\n\t\t\telse:\n\t\t\t\te *= 2\n\t\telse:\n\t\t\tif red == \"none\":\n\t\t\t\te *= 4\n\t\t\telse:\n\t\t\t\te *= 3\n\t\treturn e\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\treturn self.elems() * 2 + 1\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/misc.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Foo(OperatorLayerBase):\n\t\"\"\"\n\tAn object of Foo is instantiated when we detect an unsupported operator.\n\t\"\"\"\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tshapes = []\n\t\ttypes = []\n\n\t\tfor arg in args:\n\t\t\tif arg['type'] == \"tensor\":\n\t\t\t\tshapes.append(arg['shape'])\n\t\t\t\ttypes.append(arg['dtype'])\n\n\t\tself.shape = shapes\n\t\tself.type = types\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn 0\n\nclass Copy(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"copy_\")\n\t\tassert (len(args) == 2)\n\n\t\tdst, src = args\n\t\tassert (src['type'] == dst['type'])\n\t\tassert (src['shape'] == dst['shape'])\n\n\t\tself.shape = src['shape']\n\t\tself.stype = src['dtype']\n\t\tself.dtype = dst['dtype']\n\n\tdef params(self):\n\t\t#The data type might be different\n\t\tp = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn self.elems() * 
(Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype))\n\nclass Clone(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"clone\")\n\t\tassert (len(args) == 1)\n\t\tt = args[0]\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn 2 * self.elems() * Utility.typeToBytes(self.type)\n\nclass Contiguous(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"Tensor\")\n\t\tassert (op == \"contiguous\")\n\t\tassert (len(args) == 1)\n\t\tt = args[0]\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\nclass Any(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == 
\"Tensor\")\n\t\tassert (op == \"any\")\n\t\tassert (len(args) == 1)\t#could be 2 as well, the second argument is a bool\n\t\tt = args[0]\n\n\t\tself.shape = t['shape']\n\t\tself.type = t['dtype']\n\t\tself.sub = d.sub\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\treturn Utility.numElems(self.shape) * Utility.typeToBytes(self.type)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/normalization.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass BatchNorm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (op == \"batch_norm\")\n\t\tassert (len(args) == 8)\n\t\ti = args[0]\n\t\tassert (i['type'] == \"tensor\")\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Variance algo-dependent, but this is a reasonable value.\n\t\treturn self.elems() * 8\n\n\tdef bytes(self):\n\t\te = self.elems()\n\t\tif self.dir == \"fprop\":\n\t\t\te *= 4\n\t\telse:\n\t\t\te *= 5\n\n\t\treturn e * Utility.typeToBytes(self.type)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/optim.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\n#TODO: Add support for other optimizers.\n\nclass Adam(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert(op == \"adam\")\n\t\tassert (len(args) == 12) or (len(args) == 14)\n\t\tw, hw, m, v, g = args[0:5]\n\t\tassert (w['shape'] == m['shape'] == v['shape'] == g['shape'])\n\t\tassert (hw['shape'] == w['shape']) or (hw['shape'] == (0,))\t\t#hw could be null\n\t\tassert (w['type'] == m['type'] == v['type'] == g['type'] == hw['type'] == \"tensor\")\n\t\tassert (w['dtype'] == m['dtype'] == v['dtype'] == \"float32\")\n\n\t\tself.w = w\n\t\tself.g = g\n\n\tdef params(self):\n\t\tp = OrderedDict([('T',self.w['shape']), ('wtype',self.w['dtype']), ('gtype',self.g['dtype'])])\n\t\treturn p\n\n\tdef flops(self):\n\t\treturn 0\n\n\tdef bytes(self):\n\t\twshape = self.w['shape']\n\t\twtype = self.w['dtype']\n\t\tgtype = self.g['dtype']\n\t\tb = 0\n\n\t\telems = Utility.numElems(wshape)\n\n\t\t#Get time to stream read/write w, m, v\n\t\tb += 6 * elems *  Utility.typeToBytes(wtype)\n\n\t\t#Get time to read \"g\"\n\t\tb += elems * Utility.typeToBytes(gtype)\n\n\t\tif wtype != gtype: #mixed precision\n\t\t\t#Get time to write \"hw\n\t\t\tb += elems * Utility.typeToBytes(gtype)\n\n\t\treturn b\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/output.py",
    "content": "import errno, os, sys\n\nclass Output():\n\t\"\"\"\n\tThis class handles printing of a columed output and a CSV.\n\t\"\"\"\n\n\t# The table below is organized as \n\t# user_option: [output_header, attribute_in_Data_class, type, min_width_in_columed_output]\n\ttable = {\n\t\t\"idx\":\t\t[\"Idx\",\t\t\t\"index\",\tint,\t7],\n\t\t\"seq\":\t\t[\"SeqId\",\t\t\"seqId\",\tstr,\t7],\n\t\t\"altseq\":\t[\"AltSeqId\",\t\"altSeqId\",\tstr,\t7],\n\t\t\"tid\":\t\t[\"TId\",\t\t\t\"tid\",\t\tint,\t12],\n\t\t\"layer\":\t[\"Layer\", \t\t\"layer\",\tstr,\t10],\n\t\t\"trace\":\t[\"Trace\",\t\t\"trace\",\tstr,\t25],\n\t\t\"dir\":\t\t[\"Direction\",\t\"dir\",\t\tstr,\t5],\n\t\t\"sub\":\t\t[\"Sub\",\t\t\t\"sub\",\t\tint,\t3],\n\t\t\"mod\":\t\t[\"Module\",\t\t\"mod\",\t\tstr,\t15],\n\t\t\"op\":\t\t[\"Op\",\t\t\t\"op\",\t\tstr,\t15],\n\t\t\"kernel\":\t[\"Kernel\",\t\t\"name\",\t\tstr,\t0],\n\t\t\"params\":\t[\"Params\",\t\t\"params\",\tstr,\t0],\n\t\t\"sil\":\t\t[\"Sil(ns)\",\t\t\"sil\",\t\tint,\t10],\n\t\t\"tc\":\t\t[\"TC\",\t\t\t\"tc\",\t\tstr,\t2],\n\t\t\"device\":\t[\"Device\",\t\t\"device\",\tint,\t3],\n\t\t\"stream\":\t[\"Stream\",\t\t\"stream\",\tint,\t3],\n\t\t\"grid\":\t\t[\"Grid\",\t\t\"grid\",\t\tstr,\t12],\n\t\t\"block\":\t[\"Block\",\t\t\"block\",\tstr,\t12],\n\t\t\"flops\":\t[\"FLOPs\", \t\t\"flops\",\tint,\t12],\n\t\t\"bytes\":\t[\"Bytes\",\t\t\"bytes\", \tint,\t12]\n\t}\n\n\tdef __init__(self, args):\n\t\tself.cols = args.c\n\t\tself.csv = args.csv\n\t\tself.col = True if (args.w > 0) else False\n\t\tself.width = args.w\n\n\t\tw = 0\n\t\tfor col in self.cols:\n\t\t\tassert col in Output.table.keys()\n\t\t\tw += Output.table[col][3]\n\n\t\tif ((self.col) and (w > self.width)):\n\t\t\tprint(\"Minimum width required to print {} = {}. 
Exiting.\".format(\",\".join(self.cols), w))\n\t\t\tsys.exit(1)\n\n\t\tremainder = self.width - w\n\n\t\tif (\"kernel\" in self.cols) and (\"params\" in self.cols):\n\t\t\tOutput.table[\"kernel\"][3] = int(remainder/2)\n\t\t\tOutput.table[\"params\"][3] = int(remainder/2)\n\t\telif (\"kernel\" in self.cols):\n\t\t\tOutput.table[\"kernel\"][3] = remainder\n\t\telif (\"params\" in self.cols):\n\t\t\tOutput.table[\"params\"][3] = remainder\n\n\t\t#header format\n\t\tcadena = \"\"\n\t\tfor col in self.cols:\n\t\t\t_,_,t,w = Output.table[col]\n\t\t\tcadena += \"%-{}.{}s \".format(w,w)\n\n\t\tself.hFormat = cadena\n\n\t\t#data format\n\t\tcadena = \"\"\n\t\tfor col in self.cols:\n\t\t\t_,_,t,w = Output.table[col]\n\t\t\tif (t == str):\n\t\t\t\tcadena += \"%-{}.{}s \".format(w,w)\n\t\t\telif (t == int):\n\t\t\t\tcadena += \"%{}d \".format(w)\n\n\t\tself.dFormat = cadena\n\n\tdef foo(self, cadena, pformat):\n\t\tif self.csv:\n\t\t\tcadena = \",\".join(map(lambda x : '\"' + str(x) + '\"', cadena))\n\t\telif self.col:\n\t\t\tcadena = pformat % cadena\n\t\telse:\n\t\t\tcadena = \" \".join(map(str,cadena))\n\n\t\ttry:\n\t\t\tprint(cadena)\n\t\texcept IOError as e:\n\t\t\t#gracefully handle pipes\n\t\t\tif e.errno == errno.EPIPE:\n\t\t\t\t# Python flushes standard streams on exit; redirect remaining output\n\t\t\t\t# to devnull to avoid another BrokenPipeError at shutdown\n\n\t\t\t\tdevnull = os.open(os.devnull, os.O_WRONLY)\n\t\t\t\tos.dup2(devnull, sys.stdout.fileno())\n\t\t\t\tsys.exit(0)\n\t\t\telse:\n\t\t\t\tsys.exit(-1)\n\n\tdef header(self):\n\t\tcadena = ()\n\t\tfor col in self.cols:\n\t\t\th = Output.table[col][0]\n\t\t\tcadena = cadena + (h,)\n\n\t\tself.foo(cadena, self.hFormat)\n\n\tdef data(self, a):\n\t\tif a.dir == \"\":\n\t\t\tdirec = \"na\"\n\t\telse:\n\t\t\tdirec = a.dir\n\n\t\tif a.op == \"\":\n\t\t\top = \"na\"\n\t\telse:\n\t\t\top = a.op\n\n\t\tif a.mod == \"\":\n\t\t\tmod = \"na\"\n\t\telse:\n\t\t\tmod = a.mod\n\n\t\tcadena = ()\n\t\tfor col in 
self.cols:\n\t\t\tattr = Output.table[col][1]\n\t\t\tval = getattr(a, attr)\n\n\t\t\tif col == \"layer\":\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tval = \":\".join(val)\n\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tif col == \"trace\":\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tif self.col and len(val):\n\t\t\t\t\tval = val[-1]\n\t\t\t\t\tval = val.split(\"/\")[-1]\n\t\t\t\telse:\n\t\t\t\t\tval = \",\".join(val)\n\t\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tif col in [\"seq\", \"altseq\"]:\n\t\t\t\tassert(type(val) == list)\n\t\t\t\tval = \",\".join(map(str,val))\n\t\t\t\tval = \"-\" if val == \"\" else val\n\n\t\t\tcadena = cadena + (val,)\n\t\n\t\tself.foo(cadena, self.dFormat)\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/pointwise.py",
    "content": "import numpy as np\nfrom collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Pointwise(OperatorLayerBase):\n\n\tops = []\n\tops += [\"__abs__\", \"__neg__\", \"__invert__\"]\n\tops += [\"__add__\", \"__sub__\", \"__mul__\", \"__floordiv__\", \"__truediv__\", \"__pow__\", \"__mod__\"]\n\tops += [\"__radd__\", \"__rsub__\", \"__rmul__\", \"__rdiv__\", \"__rtruediv__\", \"__rfloordiv__\", \"__rpow__\"]\n\tops += [\"__iadd__\", \"__isub__\", \"__imul__\", \"__itruediv__\",]\n\tops += [\"__lt__\", \"__gt__\", \"__ge__\", \"__le__\", \"__eq__\", \"__ne__\",]\n\tops += [\"lt\", \"lt_\", \"gt\", \"gt_\", \"ge\", \"ge_\", \"le\", \"le_\", \"eq\", \"eq_\", \"ne\", \"ne_\",]\n\tops += [\"__and__\", \"__or__\", \"__xor__\", \"__lshift__\", \"__rshift__\"]\n\tops += [\"__iand__\", \"__ior__\", \"__ixor__\", \"__ilshift__\", \"__irshift__\"]\n\tops += [\"abs\", \"abs_\", \"neg\", \"neg_\"]\n\tops += [\"add\", \"add_\", \"div\", \"div_\", \"mul\", \"mul_\", \"reciprocal\", \"reciprocal_\", \"remainder\", \"remainder_\", \"sub\", \"sub_\",]\n\tops += [\"addcdiv\", \"addcdiv_\", \"addcmul\", \"addcmul_\"]\n\tops += [\"exp\", \"exp_\", \"exp1m\", \"exp1m_\", \"log\", \"log_\", \"log10\", \"log10_\", \"log1p\", \"log1p_\", \"log2\", \"log2_\", \"pow\", \"pow_\", \"rsqrt\", \"rsqrt_\", \"sqrt\", \"sqrt_\",]\n\tops += [\"ceil\", \"ceil_\", \"clamp\", \"clamp_\", \"floor\", \"floor_\", \"fmod\", \"fmod_\", \"frac\", \"frac_\", \"round\", \"round_\", \"sign\", \"sign_\", \"trunc\", \"trunc_\"]\n\tops += [\"acos\", \"acos_\", \"asin\", \"asin_\", \"atan\", \"atan_\", \"atan2\", \"atan2_\", \"cos\", \"cos_\", \"cosh\", \"cosh_\", \"sin\", \"sin_\", \"sinh\", \"sinh_\", \"tan\", \"tan_\", \"sigmoid\", \"sigmoid_\", \"tanh\", \"tanh_\"]\n\tops += [\"digamma\", \"erf\", \"erf_\", \"erfc\", \"erfc_\", \"erfinv\", \"erfinv_\", \"lerp\", \"lerp_\", \"mvlgamma\",]\n\n\t@staticmethod\n\tdef foo(d):\n\t\treturn 
d['name'],d['type'],d['shape'],d['dtype']\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.dir = d.dir\n\t\tassert (d.dir in [\"fprop\", \"bprop\"])\n\t\tassert (op in Pointwise.ops)\n\n\t\t#Filter out all named parameters (kwargs).\n\t\t#This might require revisiting in future.\n\t\targs = list(filter(lambda x : x['name'] == \"\", args))\n\n\t\t#Filter out non tensors\n\t\targs = list(filter(lambda x : x['type'] == \"tensor\", args))\n\n\t\tif (len(args) == 0):\n\t\t\tself.shape = [(1,)]\n\t\t\tself.type = \"float32\" #FIX\n\n\t\telif (len(args) == 1):\n\t\t\tin0 = args[0]\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\tassert (t0 == \"tensor\")\n\t\t\tself.shape = [s0,]\n\t\t\tself.type = dt0\n\n\t\telif (len(args) == 2):\n\t\t\tin0,in1 = args\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\t_,t1,s1,dt1 = Pointwise.foo(in1)\n\t\t\tassert (t0 == t1 == \"tensor\")\n\t\t\tassert (dt0 == dt1)\n\t\t\tself.shape = [s0,s1]\n\t\t\tself.type = dt0\n\n\t\telif (len(args) == 3):\n\t\t\tin0,in1,in2 = args\n\t\t\t_,t0,s0,dt0 = Pointwise.foo(in0)\n\t\t\t_,t1,s1,dt1 = Pointwise.foo(in1)\n\t\t\t_,t2,s2,dt2 = Pointwise.foo(in2)\n\t\t\tassert (t0 == t1 == t2 == \"tensor\")\n\t\t\tassert (dt0 == dt1 == dt2)\n\t\t\tself.shape = [s0,s1,s2]\n\t\t\tself.type = dt0\n\t\telse:\n\t\t\tassert False\n\t\treturn\n\n\tdef params(self):\n\t\tp = OrderedDict([('T',self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\ttensor = self.shape\n\t\tt = self.type\n\n\t\tif (len(tensor) == 1):\n\t\t\telems = 2 * Utility.numElems(tensor[0])\n\t\telif (len(tensor) == 2):\n\t\t\tif (tensor[0] == tensor[1]):\t# same shape\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\tif 
self.dir == \"fprop\":\n\t\t\t\t\telems *= 3\n\t\t\t\telse:\n\t\t\t\t\tif (self.op_ in [\"add\", \"__add__\", \"sub\", \"__sub__\", \"__isub__\"]):\n\t\t\t\t\t\telems *= 2\n\t\t\t\t\telif (self.op_ in [\"__mul__\", \"__rmul__\", \"div\", \"__truediv__\"]):\n\t\t\t\t\t\telems *= 3\n\t\t\t\t\telse:\n\t\t\t\t\t\tassert False\n\t\t\telse:\t#check for broadcast conditions\n\t\t\t\tarray1 = np.empty(list(tensor[0]))\n\t\t\t\tarray2 = np.empty(list(tensor[1]))\n\t\t\t\ttry:\n\t\t\t\t\tout = np.broadcast(array1, array2).shape\n\t\t\t\texcept:\n\t\t\t\t\tassert False\n\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\telems += Utility.numElems(tensor[1])\n\t\t\t\telems += Utility.numElems(out)\n\t\t\t\t#TODO bprop\n\t\telif (len(tensor) == 3):\n\t\t\tif (tensor[0] == tensor[1] == tensor[2]):\t#same shape\n\t\t\t\telems = Utility.numElems(tensor[0])\n\t\t\t\telems *= 4\n\t\t\telse:\n\t\t\t\tassert False\n\t\telse:\n\t\t\tassert False\n\n\t\treturn elems\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\t# Note: some cases may still be missing.\n\n\t\tf = 0\n\t\tif self.op_ in [\"__abs__\", \"__neg__\", \"__add__\", \"__sub__\", \"__mul__\",\n\t\t\t\t\t\"__radd__\", \"__rmul__\", \"__iadd__\", \"__isub__\", \"__imul__\", \"__itruediv__\",\n\t\t\t\t\t\"abs\", \"abs_\", \"neg\", \"neg_\", \"add\", \"add_\", \"div\", \"div_\", \"mul\", \"mul_\",\n\t\t\t\t\t\"sub\", \"sub_\", \"exp\", \"exp_\", \"sign\", \"sign_\", \"trunc\", \"trunc_\",\n\t\t\t\t\t\"sin\", \"sin_\", \"cos\", \"cos_\", \"sinh\", \"sinh_\", \"cosh\", \"cosh_\",\n\t\t\t\t\t\"sqrt\", \"sqrt_\", \"rsqrt\", \"rsqrt_\", \"__lt__\", \"__gt__\", \"__ge__\", \"__le__\",\n\t\t\t\t\t\"__eq__\", \"__ne__\", \"lt\", \"lt_\", \"gt\", \"gt_\", \"ge\", \"ge_\", \"le\", \"le_\",\n\t\t\t\t\t\"eq\", \"eq_\", \"ne\", \"ne_\", \"ceil\", \"ceil_\", \"clamp\", \"clamp_\", \"floor\", \"floor_\",\n\t\t\t\t\t\"round\", \"sign\", \"sign_\", \"trunc\", \"trunc_\"]:\n\t\t\t# We're 
counting only one operand, not two (2 operands, 1 op)\n\t\t\tf = self.elems() / 2\n\t\telif self.op_ in [\"fmod\", \"fmod_\"]:\n\t\t\tf = self.elems()\n\t\telif self.op_ in [\"tanh\", \"tanh_\", \"sigmoid\", \"sigmoid_\", \"log\", \"log_\", \"log2\",\n\t\t\t \"log2_\", \"log10\", \"log10_\"]:\n\t\t\tf = self.elems() * 2\n\t\telif self.op_ in [\"asin\", \"asin_\", \"acos\", \"acos_\", \"atan\", \"atan_\"]:\n\t\t\t# no intrinsic, hence slow execution\n\t\t\t# surprisingly, asin/acos and atan were all the same (via nvprof measurement)\n\t\t\tf = self.elems() * 10\n\n\t\treturn f\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/pooling.py",
    "content": "from .collections import OrderedDict\nfrom .utility import Utility\n\n# Work in progress.\n\n#poolFuncs = [\"max_pool2d_with_indices_forward\", \"max_pool2d_with_indices\"]\nclass MaxPool2d(object):\n\n\tdef parse(marker):\n\n\t\tdef convert2Tuple(arg):\n\t\t\tassert (arg['type'] in [\"int\", \"tuple\"])\n\t\t\tif arg['type'] == \"int\":\n\t\t\t\treturn (arg['value'], arg['value'])\n\t\t\telse:\n\t\t\t\treturn arg['value']\n\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"max_pool2d\")\n\t\tassert (len(args) >= 2)\n\n\t\t#input\n\t\tassert (args[0]['name'] == \"\")\n\t\tinp = args[0]\n\t\tassert (inp['type'] == \"tensor\")\n\t\ti = inp['shape']\n\t\tt = inp['dtype']\n\t\tassert (len(i) == 4) #nchw tensor\n\n\t\t#kernel\n\t\tif (args[1]['name'] == \"\"):\n\t\t\tk = args[1]\n\t\telse:\n\t\t\tk = list(filter(lambda x : x['name'] == \"kernel_size\", args))[0]\n\t\tk = convert2Tuple(k)\n\n\t\t#stride\n\t\ts = k #default value\n\t\tif ((len(args) >= 3) and args[2] == \"\"):\n\t\t\ts = args[2]\n\t\t\ts = convert2Tuple(s)\n\t\telif any(x['name'] == \"stride\" for x in args):\n\t\t\ts = list(filter(lambda x : x['name'] == \"stride\", args))[0]\n\t\t\ts = convert2Tuple(s)\n\n\t\t#padding\n\t\tp = (0,0)\n\t\tif ((len(args) >= 4) and args[3] == \"\"):\n\t\t\tp = args[3]\n\t\t\tp = convert2Tuple(p)\n\t\telif any(x['name'] == \"padding\" for x in args):\n\t\t\tp = list(filter(lambda x : x['name'] == \"padding\", args))[0]\n\t\t\tp = convert2Tuple(p)\n\t\t\n\t\tparams = OrderedDict([('T', i), ('K', k), ('s',s), ('p',p), ('type', t)])\n\t\treturn params\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/prof.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nThis script reads the output (Python dictionary) created by parse.py.\nFor every kernel (line) in the input it determines\n\tmodule / class name e.g. torch.nn.functional\n\toperator name e.g. linear\n\tkernel parameters e.g. GEMM M, N, K, datatype\n\tbytes\n\tflops\n\ttensor core usage\n\tdirection (fprop, bprop)\n\tand other things. Please see the tool usage.\n\"\"\"\n\nfrom .usage import parseArgs\nfrom .output import Output\nfrom .utility import Utility\nfrom .pointwise import Pointwise\nfrom .convert import Convert\nfrom .blas import *\nfrom .embedding import Embedding\nfrom .reduction import *\nfrom .dropout import Dropout\nfrom .softmax import *\n#from pooling import * # work in progress\nfrom .linear import Linear\nfrom .optim import Adam\nfrom .misc import *\nfrom .conv import Conv\nfrom .activation import Activation\nfrom .index_slice_join_mutate import Cat, Reshape, MaskedScatter, Gather, Nonzero, IndexSelect, MaskedSelect\nfrom .recurrentCell import RNNCell\nfrom .normalization import BatchNorm\nfrom .randomSample import RandPerm\nfrom .loss import MSELoss\nfrom .data import Data\n\ndef findFpropKernel(seq):\n\t#Find the last fprop kernel with the same seqId\n\t#First look at seqId and then at altSeqId\n\tfor idx in reversed(range(len(kernels))):\n\t\tk = kernels[idx]\n\t\tif (seq in k['seqId']) and (k['dir'] == \"fprop\"):\n\t\t\treturn idx\n\n\tfor idx in reversed(range(len(kernels))):\n\t\tk = kernels[idx]\n\t\tif (seq in k['altSeqId']) and (k['dir'] == \"fprop\"):\n\t\t\treturn idx\n\n\treturn -1\n\t#print(\"Error: seqId {} not found.\".format(seq), file=sys.stderr)\n\t#assert False\n\ndef foo(mod, op, d):\n\tif (op[0] == \"linear\"):\n\t\txx = Linear(d)\n\n\t# rnncell, lstmcell, grucell\n\telif (mod[0] in[\"LSTMCell\", \"GRUCell\"]) and (op[0] == \"forward\"):\n\t\txx = RNNCell(d)\n\n\telif op[0] in [\"conv1d\", \"conv2d\",]:\n\t\txx = Conv(d)\n\n\telif (op[0] in Pointwise.ops):\n\t\txx = 
Pointwise(d)\n\n\telif (op[0] in Convert.ops):\n\t\txx = Convert(d)\n\n\telif op[0] in [\"__matmul__\", \"matmul\"]:\n\t\txx = Matmul(d)\n\n\telif op[0] == \"embedding\":\n\t\txx = Embedding(d)\n\n\t#reduction\n\telif op[0] == \"sum\":\n\t\txx = Sum(d)\n\n\telif op[0] == \"mean\":\n\t\txx = Mean(d)\n\n\telif op[0] == \"norm\":\n\t\txx = Norm(d)\n\n\telif op[0] == \"dropout\":\n\t\txx = Dropout(d)\n\n\t#Index, Slice, Join, Mutate\n\telif (op[0] == \"cat\"):\n\t\txx = Cat(d)\n\n\telif (op[0] == \"reshape\"):\n\t\txx = Reshape(d)\n\n\telif (op[0] == \"masked_scatter_\"):\n\t\txx = MaskedScatter(d)\n\n\telif (op[0] == \"gather\"):\n\t\txx = Gather(d)\n\n\telif (op[0] == \"nonzero\"):\n\t\txx = Nonzero(d)\n\n\telif (op[0] == \"index_select\"):\n\t\txx = IndexSelect(d)\n\n\telif (op[0] == \"masked_select\"):\n\t\txx = MaskedSelect(d)\n\n\t#blas\n\telif op[0] in [\"addmm\", \"addmm_\"]:\n\t\txx = Addmm(d)\n\n\telif op[0] == \"mm\":\n\t\txx = Mm(d)\n\n\telif op[0] == \"bmm\":\n\t\txx = Bmm(d)\n\n\t#softmax\n\telif op[0] == \"softmax\":\n\t\txx = Softmax(d)\n\n\telif op[0] == \"log_softmax\":\n\t\txx = LogSoftmax(d)\n\n\t#loss\n\telif op[0] == \"mse_loss\":\n\t\txx = MSELoss(d)\n\n\t#optimizers\n\telif op[0] == \"adam\":\n\t\txx = Adam(d)\n\n\t#normalization\n\telif op[0] == \"batch_norm\":\n\t\txx = BatchNorm(d)\n\n\t#random\n\telif op[0] == \"randperm\":\n\t\txx = RandPerm(d)\n\n\t#misc\n\telif op[0] == \"copy_\":\n\t\txx = Copy(d)\n\n\telif op[0] == \"clone\":\n\t\txx = Clone(d)\n\n\telif op[0] == \"contiguous\":\n\t\txx = Contiguous(d)\n\n\telif op[0] == \"any\":\n\t\txx = Any(d)\n\n\telif (op[0] in Activation.ops):\n\t\txx = Activation(d)\n\n\telif op[0] == \"to\":\n\t\txx = Convert(d)\n\n\telse:\n\t\txx = Foo(d)\n\n\treturn xx\n\ndef main():\n\t#Read cmd line arguments\n\tcmdArgs = parseArgs()\n\n\toutput = Output(cmdArgs)\n\toutput.header()\n\n\tidx = -1\n\t#Read in all the kernel info\n\tfor line in cmdArgs.file:\n\t\tidx += 1\n\t\tkernel = 
eval(line)\n\t\tassert(kernel)\n\t\tkernels.append(kernel)\n\n\t\tk = kernel\n\t\td = Data(k)\n\n\t\tmod = k['mod']\n\t\top = k['op']\n\n\t\tflops = 0\n\t\tparams = {\"na\":\"na\"}\n\t\ttc = \"na\"\n\t\tbytes = 0\n\n\t\tif (d.dir == \"bprop\"):\n\t\t\td.seqMarker = k['seqMarker']\n\t\t\tseq = k['seqId'][:1]\n\t\t\tassert (len(seq) == 1), seq\n\t\t\t#assert (seq[0] != 0)\n\t\t\tassert (len(d.seqMarker) > 0)\n\t\t\t#If there is no useful marker associated, use the\n\t\t\t#sequence number to find the kernel from fprop\n\t\t\tif len(d.argMarker) == 0:\n\t\t\t\tindex = findFpropKernel(seq[0])\n\t\t\t\tif index >= 0:\n\t\t\t\t\td.argMarker = kernels[index]['marker']\n\t\t\t\t\td.modMarker = kernels[index]['reprMarkers']\n\t\t\t\t\tmod = kernels[index]['mod']\n\t\t\t\t\top = kernels[index]['op']\n\n\t\t\t\t\td.layer = kernels[index]['layer']\n\t\t\t\t\td.trace = kernels[index]['trace']\n\n\t\t# Check if marker has our annotations\n\t\tif len(d.argMarker) and Utility.hasNVTX(d.argMarker[0]):\n\n\t\t\txx = foo(mod, op, d)\n\n\t\t\tbytes = xx.bytes()\n\t\t\tflops = xx.flops()\n\t\t\top = xx.op()\n\t\t\tparams = xx.params()\n\t\t\ttc = xx.tc()\n\n\t\tif type(op) is list:\n\t\t\tif len(op):\n\t\t\t\top = op[0]\n\t\t\telse:\n\t\t\t\top = \"\"\n\n\t\tif type(mod) is list:\n\t\t\tif len(mod):\n\t\t\t\tmod = mod[0]\n\t\t\telse:\n\t\t\t\tmod = \"\"\n\n\t\td.index = idx+1\n\n\t\t# The following attributes come from the operator class functions.\n\t\td.setParams(params)\n\t\td.tc = tc\n\t\td.flops = flops\n\t\td.bytes = bytes\n\t\td.mod = mod\n\t\td.op = op\n\n\t\toutput.data(d)\n\nkernels = []\nif __name__ == '__main__':\n\tmain()\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/randomSample.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass RandPerm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch\")\n\t\tassert (op == \"randperm\")\n\t\tassert (len(args) == 1)\n\t\tn = args[0]\n\t\tassert n['type'] == \"int\"\n\t\tself.n = n['value']\n\n\tdef params(self):\n\t\tp = OrderedDict([('N', self.n)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\treturn self.n * Utility.typeToBytes(\"int64\")\n\n\tdef flops(self):\n\t\t# Depends on RNG but this is probably a reasonable assumption.\n\t\treturn self.n * 3\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/recurrentCell.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\ndef hasTileSize(name):\n\tif (\"sgemm\" in name) or (\"884gemm\" in name) or (\"hgemm\" in name):\n\t\treturn True\n\telse:\n\t\treturn False\n\ndef ctaTile(name):\n\tname = name.split(\"_\")\n\tname = list(filter(lambda x : \"x\" in x, name))\n\tname = list(filter(lambda x : \"slice\" not in x, name))\n\tassert(len(name) == 1)\n\tname = name[0].split(\"x\")\n\tassert(len(name) == 2)\n\tname = list(map(int, name))\n\treturn name[0], name[1]\n\nclass RNNCell(OperatorLayerBase):\n\t\"\"\"\n\tThis class supports RNNCell, LSTMCell and GRUCell.\n\t\"\"\"\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tself.name = d.name\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\t\tself.grid = d.grid\n\n\t\tassert (op == \"forward\")\n\t\tassert (mod in [\"LSTMCell\", \"GRUCell\", \"RNNCell\"])\n\t\tassert (len(args) in [2,3])\n\n\t\tx,h = args[0],args[1]\n\t\tb1,ii = x['shape']\n\t\tb2,hh = h['shape']\n\t\tassert b1 == b2\n\t\tassert x['dtype'] == h['dtype']\n\t\tt = x['dtype']\n\n\t\tself.cell = mod\n\t\tself.inp = ii\n\t\tself.hid = hh\n\t\tself.b = b1\n\t\tself.type = t\n\n\t\tself.multiple = 1\n\t\tif self.cell == \"LSTMCell\":\n\t\t\tself.multiple = 4\n\t\telif self.cell == \"GRUCell\":\n\t\t\tself.multiple = 3\n\n\t\tself.gemm = None\n\t\tself.m = None\n\t\tself.n = None\n\t\tself.k = None\n\t\tself.elems = 0\n\n\t\tself.bar()\n\t\t\n\tdef params(self):\n\t\tif self.gemm is None:\n\t\t\tp = OrderedDict([('cell', self.cell), ('X', self.inp), ('H', self.hid), ('B', self.b), ('type', self.type)])\n\t\telse:\n\t\t\tassert self.m is not None\n\t\t\tassert self.n is not None\n\t\t\tassert self.k is not None\n\t\t\tp = OrderedDict([('gemm', self.gemm), ('M', self.m), ('N', 
self.n), ('K', self.k), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\tif \"gemm\" in self.name:\n\t\t\treturn 1 if \"884gemm\" in self.name else 0\n\t\telse:\n\t\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef bytes(self):\n\t\tif self.gemm is not None:\n\t\t\tm, n, k, t = self.m, self.n, self.k, self.type\n\t\t\tb = (m*k + k*n + m*n) * Utility.typeToBytes(t)\n\t\telif self.elems != 0:\n\t\t\tb = self.elems * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\tb = 0\n\t\treturn b\n\n\tdef flops(self):\n\t\tif self.gemm is not None:\n\t\t\tm, n, k = self.m, self.n, self.k\n\t\t\tf = 2*m*n*k\n\t\telif self.elems != 0:\n\t\t\tf = 0 #TODO\n\t\telse:\n\t\t\tf = 0\n\t\treturn f\n\n\tdef bar(self):\n\t\tcell = self.cell\n\t\tX = self.inp\n\t\tH = self.hid\n\t\tB = self.b\n\t\tt = self.type\n\t\tsubseqId = self.sub\n\t\tdirec = self.dir\n\t\tname = self.name\n\t\tgrid = self.grid\n\t\tmultiple = self.multiple\n\n\t\tif direc == \"fprop\":\n\t\t\tsubseqId = subseqId % 3\n\t\t\tif subseqId == 0: #layer gemm\n\t\t\t\tself.gemm = \"layer\"\n\t\t\t\tself.m = multiple*H\n\t\t\t\tself.n = B\n\t\t\t\tself.k = X\n\t\t\telif subseqId == 1: #recurrent gemm\n\t\t\t\tself.gemm = \"recur\"\n\t\t\t\tself.m = multiple*H\n\t\t\t\tself.n = B\n\t\t\t\tself.k = H\n\t\t\telse:\n\t\t\t\tlayerGemmElems = multiple*H*B\n\t\t\t\trecurGemmElems = multiple*H*B\n\t\t\t\tcElems = H*B\n\t\t\t\thElems = H*B\n\t\t\t\ttotElems = layerGemmElems + recurGemmElems + 2*cElems + hElems\n\t\t\t\tself.elems = totElems\n\n\t\telse:\n\t\t\tif (\"gemm\" in name) and hasTileSize(name):\t#gemm\n\t\t\t\t#Get cta tile size\n\t\t\t\ttileX, tileY = ctaTile(name)\n\t\t\t\t#Get grid dimensions\n\t\t\t\tgrid = grid.split(\",\")\n\t\t\t\tgridX,gridY,gridZ = map(lambda x : int(x), grid)\n\n\t\t\t\tgemmM = tileX * gridX\n\t\t\t\tgemmN = tileY * gridY\n\n\t\t\t\tif name[-3:] == \"_nn\": # dgrad\n\t\t\t\t\tif (gemmM == H):\t# recurrent 
dgrad\n\t\t\t\t\t\t#Ideally gemmN = B, but we have a limited set of tile sizes.\n\t\t\t\t\t\tgemmN = B\n\t\t\t\t\t\tgemmK = multiple*H\n\n\t\t\t\t\t\tself.gemm = \"recur\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telif (gemmM == X):\t# layer dgrad\n\t\t\t\t\t\t#assert(gemmN % B == 0)\n\t\t\t\t\t\tgemmK = multiple*H\n\n\t\t\t\t\t\tself.gemm = \"layer\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telse:\n\t\t\t\t\t\tpass\n\n\t\t\t\telif name[-3:] == \"_nt\": #wgrad\n\t\t\t\t\tif (gemmM == H):\t#recurrent wgrad\n\t\t\t\t\t\tassert (gemmN == multiple*H)\n\t\t\t\t\t\tgemmK = B\n\n\t\t\t\t\t\tself.gemm = \"recur\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telif (gemmM == X):\t#layer wgrad\n\t\t\t\t\t\tassert (gemmN == multiple*H)\n\t\t\t\t\t\tgemmK = B\n\n\t\t\t\t\t\tself.gemm = \"layer\"\n\t\t\t\t\t\tself.m = gemmM\n\t\t\t\t\t\tself.n = gemmN\n\t\t\t\t\t\tself.k = gemmK\n\n\t\t\t\t\telse:\n\t\t\t\t\t\tpass\n\t\t\t\telse:\n\t\t\t\t\tpass\n\t\t\telse:\n\t\t\t\tpass\n\n\t\treturn\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/reduction.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Mean(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"mean\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\t\ti = args[0]\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\t\tself.sub = d.sub\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\tif self.sub == 0:\n\t\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\t\telse:\n\t\t\treturn 0\n\n\tdef flops(self):\n\t\tif self.sub == 0:\n\t\t\treturn self.elems() + 1\n\t\telse:\n\t\t\treturn 0\n\nclass Sum(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"sum\")\n\t\tassert (len(args) >= 1)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\ti = args[0]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef 
mod(self):\n\t\treturn self.mod_\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: This is incorrect, need to calculate actual flops (say via nvprof)\n\t\treturn self.elems()\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\nclass Norm(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod in [\"torch\", \"Tensor\"])\n\t\tassert (op == \"norm\")\n\t\t#assert (len(args) == 1)\n\t\ti = args[0]\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef bytes(self):\n\t\treturn self.elems() * Utility.typeToBytes(self.type)\n\n\tdef flops(self):\n\t\t# square and add plus sqrt\n\t\treturn 2 * self.elems() + 1\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/softmax.py",
    "content": "from collections import OrderedDict\nfrom .utility import Utility\nfrom .base import OperatorLayerBase\n\nclass Softmax(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"softmax\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\t\tself.shape = args[0]['shape']\n\t\tself.type = args[0]['dtype']\n\t\tself.dir = d.dir\n\n\t\treturn\n\n\tdef op(self):\n\t\treturn self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: exp, sum-reduce, divide\n\t\t#flops = elems * 3\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * Utility.typeToBytes(self.type)\n\t\tb *= 3 if self.dir == \"fprop\" else 5 #verify\n\t\treturn b\n\nclass LogSoftmax(OperatorLayerBase):\n\n\tdef __init__(self, d):\n\t\tmarker = eval(d.argMarker[0])\n\t\tmod = marker['mod']\n\t\top = marker['op']\n\t\targs = marker['args']\n\n\t\tself.marker = marker\n\t\tself.mod_ = mod\n\t\tself.op_ = op\n\t\tself.args = args\n\n\t\tassert (mod == \"torch.nn.functional\")\n\t\tassert (op == \"log_softmax\")\n\n\t\t#Filter out named parameters\n\t\targs = list(filter(lambda x : x['name'] == '', args))\n\n\t\tassert (len(args) <= 2)\n\n\t\t#Get input\n\t\tif (args[0]['name'] == \"\"):\n\t\t\ti = args[0]\n\t\telse:\n\t\t\ti = list(filter(lambda x : x['name'] == \"input\", args))[0]\n\n\t\tt = i['dtype']\n\n\t\tself.shape = i['shape']\n\t\tself.type = i['dtype']\n\t\tself.dir = d.dir\n\t\treturn\n\n\tdef op(self):\n\t\treturn 
self.op_\n\n\tdef mod(self):\n\t\treturn self.mod_\n\n\tdef tc(self):\n\t\treturn \"-\"\n\n\tdef params(self):\n\t\tp = OrderedDict([('T', self.shape), ('type', self.type)])\n\t\treturn p\n\n\tdef elems(self):\n\t\treturn Utility.numElems(self.shape)\n\n\tdef flops(self):\n\t\t# Note: exp, sum-reduce, divide, log\n\t\t#flops = elems * 4\n\t\treturn 0\n\n\tdef bytes(self):\n\t\tb = self.elems() * Utility.typeToBytes(self.type)\n\t\tb *= 3 if self.dir == \"fprop\" else 5 #verify\n\t\treturn b\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/usage.py",
    "content": "import sys\nimport argparse\n\ndef parseArgs():\n\t\"\"\"\n\tPrint usage and parse arguments.\n\t\"\"\"\n\n\tdef check_cols(value):\n\t\tvalid = [\"idx\", \"seq\", \"altseq\", \"tid\", \"layer\", \"trace\", \"dir\", \"sub\", \"mod\", \"op\", \"kernel\", \"params\", \"sil\", \"tc\", \"device\", \"stream\", \"grid\", \"block\", \"flops\", \"bytes\"]\n\t\tcols = value.split(\",\")\n\t\tfor col in cols:\n\t\t\tif col not in valid:\n\t\t\t\traise argparse.ArgumentTypeError(\"{} is not a valid column name. Valid column names are {}.\".format(col, \",\".join(valid)))\n\t\treturn cols\n\n\tdef openFile(f):\n\t\ttry:\n\t\t\td = open(f, \"r\")\n\t\t\treturn d\n\t\texcept IOError:\n\t\t\tprint(\"Error opening file {}. Exiting.\".format(f), file=sys.stderr)\n\t\t\tsys.exit(1)\n\n\tparser = argparse.ArgumentParser(prog=sys.argv[0], description=\"PyTorch Profiler\", formatter_class=argparse.RawTextHelpFormatter)\n\tparser.add_argument(\"file\",\n\t\tnargs='?',\n\t\ttype=str,\n\t\tdefault=None,\n\t\thelp=\"Output of parse.py (Python dictionary).\")\n\n\tparser.add_argument(\"-c\",\n\t\ttype=check_cols,\n\t\tdefault=\"idx,dir,sub,mod,op,kernel,params,sil\",\n\t\thelp='''Comma seperated names of columns to print.\nidx:      Index\nseq:      PyTorch Sequence Id\naltseq:   PyTorch Alternate Sequence Id\ntid:      Thread Id\nlayer:    User annotated NVTX string (can be nested)\ntrace:    Function Call Trace\ndir:      Direction\nsub:      Sub Sequence Id\nmod:      Module\nop:       Operattion\nkernel:   Kernel Name\nparams:   Parameters\nsil:      Silicon Time (in ns)\ntc:       Tensor Core Usage\ndevice:   GPU Device Id\nstream:   Stream Id\ngrid:     Grid Dimensions\nblock:    Block Dimensions\nflops:    Floating point ops (FMA = 2 FLOPs)\nbytes:    Number of bytes in and out of DRAM\ne.g. 
-c idx,kernel,sil''')\n\n\tgroup = parser.add_mutually_exclusive_group()\n\tgroup.add_argument(\"--csv\",\n\t\taction=\"store_true\",\n\t\tdefault=False,\n\t\thelp=\"Print a CSV output.\")\n\tgroup.add_argument(\"-w\",\n\t\ttype=int,\n\t\tdefault=0,\n\t\thelp=\"Width of columnar output.\")\n\n\targs = parser.parse_args()\n\tif args.file is None:\n\t\targs.file = sys.stdin\n\telse:\n\t\targs.file = openFile(args.file)\n\treturn args\n"
  },
  {
    "path": "KoSimCSE/apex/pyprof/prof/utility.py",
    "content": "from functools import reduce\n\nclass Utility(object):\n\n\t@staticmethod\n\tdef numElems(shape):\n\t\tassert (type(shape) == tuple)\n\t\treturn reduce(lambda x,y: x*y, shape, 1)\n\n\t@staticmethod\n\tdef typeToBytes(t):\n\t\tif (t in [\"uint8\", \"int8\", \"byte\", \"char\", \"bool\"]):\n\t\t\treturn 1\n\t\telif (t in [\"float16\", \"half\", \"int16\", \"short\"]):\n\t\t\treturn 2\n\t\telif (t in [\"float32\", \"float\", \"int32\", \"int\"]):\n\t\t\treturn 4\n\t\telif (t in [\"int64\", \"long\", \"float64\", \"double\"]):\n\t\t\treturn 8\n\t\tassert False\n\n\t@staticmethod\n\tdef typeToString(t):\n\t\tif (t in [\"uint8\", \"byte\", \"char\",]):\n\t\t\treturn \"uint8\"\n\t\telif (t in [\"int8\",]):\n\t\t\treturn \"int8\"\n\t\telif (t in [\"int16\", \"short\",]):\n\t\t\treturn \"int16\"\n\t\telif (t in [\"float16\", \"half\"]):\n\t\t\treturn \"fp16\"\n\t\telif (t in [\"float32\", \"float\"]):\n\t\t\treturn \"fp32\"\n\t\telif (t in [\"int32\", \"int\",]):\n\t\t\treturn \"int32\"\n\t\telif (t in [\"int64\", \"long\"]):\n\t\t\treturn \"int64\"\n\t\telif (t in [\"float64\", \"double\",]):\n\t\t\treturn \"fp64\"\n\t\telif (t in [\"bool\",]):\n\t\t\treturn \"bool\"\n\t\tassert False\n\n\t@staticmethod\n\tdef hasNVTX(marker):\n\t\tif type(marker) is str:\n\t\t\ttry:\n\t\t\t\tmarker = eval(marker)\n\t\t\texcept:\n\t\t\t\treturn False\n\n\t\tif type(marker) is dict:\n\t\t\tkeys  = marker.keys()\n\t\t\treturn (\"mod\" in keys) and (\"op\" in keys) and (\"args\" in keys)\n\t\telse:\n\t\t\treturn False\n\n\t@staticmethod\n\tdef isscalar(t):\n\t\treturn (t in [\"float\", \"int\"])\n"
  },
  {
    "path": "KoSimCSE/apex/reparameterization/README.md",
    "content": "Under construction...\n"
  },
  {
    "path": "KoSimCSE/apex/reparameterization/__init__.py",
    "content": "from .weight_norm import WeightNorm\nfrom .reparameterization import Reparameterization\n\ndef apply_weight_norm(module, name='', dim=0, hook_child=True):\n    r\"\"\"\n    Applies weight normalization to a parameter in the given module.\n    If no parameter is provided, applies weight normalization to all\n    parameters in model (except 1-d vectors and scalars).\n\n    .. math::\n         \\mathbf{w} = g \\dfrac{\\mathbf{v}}{\\|\\mathbf{v}\\|}\n\n    Weight normalization is a reparameterization that decouples the magnitude\n    of a weight tensor from its direction. This replaces the parameter specified\n    by `name` (e.g. \"weight\") with two parameters: one specifying the magnitude\n    (e.g. \"weight_g\") and one specifying the direction (e.g. \"weight_v\").\n    Weight normalization is implemented via a hook that recomputes the weight\n    tensor from the magnitude and direction before every :meth:`~Module.forward`\n    call.\n\n    By default, with `dim=0`, the norm is computed independently per output\n    channel/plane. To compute a norm over the entire weight tensor, use\n    `dim=None`.\n\n    See https://arxiv.org/abs/1602.07868\n\n    Args:\n        module (nn.Module): containing module\n        name (str, optional): name of weight parameter\n        dim (int, optional): dimension over which to compute the norm\n        hook_child (boolean, optional): adds reparameterization hook to direct parent of the \n            parameters. If False, it's added to `module` instead. 
Default: True\n\n    Returns:\n        The original module with the weight norm hook\n\n    Example::\n\n        >>> m = apply_weight_norm(nn.Linear(20, 40), name='weight')\n        Linear (20 -> 40)\n        >>> m.weight_g.size()\n        torch.Size([40, 1])\n        >>> m.weight_v.size()\n        torch.Size([40, 20])\n\n    \"\"\"\n    return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,\n                                    name=name, dim=dim)\n\ndef remove_weight_norm(module, name='', remove_all=False):\n    \"\"\"\n    Removes the weight normalization reparameterization of a parameter from a module.\n    If no parameter is supplied then all weight norm parameterizations are removed.\n    Args:\n        module (nn.Module): containing module\n        name (str, optional): name of weight parameter\n    Example:\n        >>> m = apply_weight_norm(nn.Linear(20, 40))\n        >>> remove_weight_norm(m)\n    \"\"\"\n    return remove_reparameterization(module, reparameterization=WeightNorm,\n                                    name=name, remove_all=remove_all)\n\ndef apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):\n    \"\"\"\n    Applies a given weight reparameterization (such as weight normalization) to\n    a parameter in the given module. If no parameter is given, applies the reparameterization\n    to all parameters in the model (except 1-d vectors and scalars).\n\n    Args:\n        module (nn.Module): containing module\n        reparameterization (Reparameterization): reparameterization class to apply\n        name (str, optional): name of weight parameter\n        dim (int, optional): dimension over which to perform reparameterization op\n        hook_child (boolean, optional): adds reparameterization hook to direct parent of the \n            parameters. If False, it's added to `module` instead. 
Default: True\n\n    Returns:\n        The original module with the reparameterization hook\n\n    Example::\n\n        >>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)\n        Linear (20 -> 40)\n\n    \"\"\"\n    assert reparameterization is not None\n    if name != '':\n        Reparameterization.apply(module, name, dim, reparameterization, hook_child)\n    else:\n        names = list(module.state_dict().keys())\n        for name in names:\n            apply_reparameterization(module, reparameterization, name, dim, hook_child)\n    return module\n\ndef remove_reparameterization(module, reparameterization=Reparameterization,\n                                name='', remove_all=False):\n    \"\"\"\n    Removes the given reparameterization of a parameter from a module.\n    If no parameter is supplied then all reparameterizations are removed.\n    Args:\n        module (nn.Module): containing module\n        reparameterization (Reparameterization): reparameterization class to remove\n        name (str, optional): name of weight parameter\n        remove_all (bool, optional): if True, remove all reparameterizations of the given type. 
Default: False\n    Example:\n        >>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)\n        >>> remove_reparameterization(m)\n    \"\"\"\n    if name != '' or remove_all:\n        to_remove = []\n        for k, hook in module._forward_pre_hooks.items():\n            if isinstance(hook, reparameterization) and (hook.name == name or remove_all):\n                hook.remove(module)\n                to_remove.append(k)\n        if len(to_remove) > 0:\n            for k in to_remove:\n                del module._forward_pre_hooks[k]\n            return module\n        if not remove_all:\n            raise ValueError(\"reparameterization of '{}' not found in {}\"\n                             .format(name, module))\n    else:\n        modules = [module]+[x for x in module.modules()]\n        for m in modules:\n            remove_reparameterization(m, reparameterization=reparameterization, remove_all=True)\n        return module\n"
  },
  {
    "path": "KoSimCSE/apex/reparameterization/reparameterization.py",
    "content": "import torch\nfrom torch.nn.parameter import Parameter\nimport sys\nclass Reparameterization(object):\n    \"\"\"\n    Class interface for performing weight reparameterizations\n    Arguments:\n        name (str): name of weight parameter\n        dim (int): dimension over which to compute the norm\n        module (nn.Module): parent module to which param `name` is registered to\n        retain_forward (bool, optional): if False deletes weight on call to \n            module.backward. Used to avoid memory leaks with DataParallel Default: True\n    Attributes:\n        reparameterization_names (list, str): contains names of all parameters \n            needed to compute reparameterization.\n        backward_hook_key (int): torch.utils.hooks.RemovableHandle.id for hook used in module backward pass.\n    \"\"\"\n\n    def __init__(self, name, dim, module, retain_forward=True):\n        self.name = name\n        self.dim = dim\n        self.evaluated = False\n        self.retain_forward = retain_forward\n        self.reparameterization_names = []\n        self.backward_hook_key = None\n        self.module = module\n\n    def compute_weight(self, module=None, name=None):\n        \"\"\"\n        Computes reparameterized weight value to assign value to module attribute\n        with name `name`.\n        See WeightNorm class for example.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n        Returns:\n            w (Tensor): Tensor object containing value of reparameterized weight\n        \"\"\"\n        raise NotImplementedError\n\n    def reparameterize(self, name, weight, dim):\n        \"\"\"\n        Creates Parameters to be used for reparameterization and creates names that\n        for attributes for the module these Parameters will correspond to.\n        The parameters will be registered according to the names provided.\n        See WeightNorm class for example.\n        Arguments:\n          
  module (nn.Module): module with weight we'd like to reparameterize\n            name (str, optional): name of weight parameter\n            dim (int, optional): dimension over which to compute parameterization\n        Returns:\n            names (list, str): names of Parameters to be used for reparameterization\n            params (list, Parameter): Parameters to be used for reparameterization\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def apply(module, name, dim, reparameterization=None, hook_child=True):\n        \"\"\"\n        Applies reparametrization to module's `name` parameter and modifies instance attributes as appropriate.\n        `hook_child` adds reparameterization hook to direct parent of the parameters. If False, it's added to `module` instead.\n        \"\"\"\n        if reparameterization is None:\n            reparameterization = Reparameterization\n        module2use, name2use = Reparameterization.get_module_and_name(module, name)\n        # does not work on sparse\n        if name2use is None or isinstance(module2use, (torch.nn.Embedding, torch.nn.EmbeddingBag)):\n            return\n\n        if hook_child:\n            fn = reparameterization(name2use, dim, module2use)\n        else:\n            fn = reparameterization(name, dim, module)\n\n        weight = getattr(module2use, name2use)\n        if weight.dim() <= 1:\n            return\n\n        # remove weight from parameter list\n        del module2use._parameters[name2use]\n\n        # add parameters of reparameterization of parameter to module\n        names, params = fn.reparameterize(name2use, weight, dim)\n        for n, p in zip(names, params):\n            module2use.register_parameter(n, p)\n\n        # add parameters to reparameterization so they can be removed later\n        fn.reparameterization_names = names\n\n        setattr(module2use, name2use, None)\n\n        hook_module = module2use\n        if not hook_child:\n            
hook_module = module\n        # recompute weight before every forward()\n        hook_module.register_forward_pre_hook(fn)\n\n        # remove weight during backward\n        handle = hook_module.register_backward_hook(fn.backward_hook)\n        # get hook key so we can delete it later\n        fn.backward_hook_key = handle.id\n\n        return fn\n\n    @staticmethod\n    def get_module_and_name(module, name):\n        \"\"\"\n        recursively fetches (possible) child module and name of weight to be reparameterized\n        \"\"\"\n        name2use = None\n        module2use = None\n        names = name.split('.')\n        if len(names) == 1 and names[0] != '':\n            name2use = names[0]\n            module2use = module\n        elif len(names) > 1:\n            module2use = module\n            name2use = names[0]\n            for i in range(len(names)-1):\n                module2use = getattr(module2use, name2use)\n                name2use = names[i+1]\n        return module2use, name2use\n\n    def get_params(self, module):\n        \"\"\"gets params of reparameterization based on known attribute names\"\"\"\n        return [getattr(module, n) for n in self.reparameterization_names]\n\n    def remove(self, module):\n        \"\"\"removes reparameterization and backward hook (does not remove forward hook)\"\"\"\n        module2use, name2use = Reparameterization.get_module_and_name(module, self.name)\n        for p in self.get_params(module2use):\n            p.requires_grad = False\n        weight = self.compute_weight(module2use, name2use)\n        delattr(module2use, name2use)\n        for n in self.reparameterization_names:\n            del module2use._parameters[n]\n        module2use.register_parameter(name2use, Parameter(weight.data))\n        del module._backward_hooks[self.backward_hook_key]\n\n    def __call__(self, module, inputs):\n        \"\"\"callable hook for forward pass\"\"\"\n        module2use, name2use = 
Reparameterization.get_module_and_name(module, self.name)\n        _w = getattr(module2use, name2use)\n        if not self.evaluated or _w is None:\n            setattr(module2use, name2use, self.compute_weight(module2use, name2use))\n            self.evaluated = True\n\n    def backward_hook(self, module, grad_input, grad_output):\n        \"\"\"callable hook for backward pass; flags the weight for recomputation\"\"\"\n        self.evaluated = False\n"
  },
  {
    "path": "KoSimCSE/apex/reparameterization/weight_norm.py",
    "content": "import torch\nfrom torch.nn.parameter import Parameter\nfrom ..fp16_utils import Fused_Weight_Norm\nimport time\n\nfrom .reparameterization import Reparameterization\n\ndef _norm(p, dim):\n    \"\"\"Computes the norm over all dimensions except dim\"\"\"\n    if dim is None:\n        return p.norm()\n    elif dim == 0:\n        output_size = (p.size(0),) + (1,) * (p.dim() - 1)\n        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)\n    elif dim == p.dim() - 1:\n        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)\n        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)\n    return _norm(p.transpose(0, dim), 0).transpose(0, dim)\n\nHALF_TYPES = (torch.cuda.HalfTensor, torch.HalfTensor)\n\nclass WeightNorm(Reparameterization):\n    r\"\"\"\n    Weight normalization is a reparameterization that decouples the magnitude\n    of a weight tensor from its direction. This replaces the parameter specified\n    by `name` (e.g. \"weight\") with two parameters: one specifying the magnitude\n    (e.g. \"weight_g\") and one specifying the direction (e.g. \"weight_v\").\n    Weight normalization is implemented via a hook that recomputes the weight\n    tensor from the magnitude and direction before every :meth:`~Module.forward`\n    call.\n\n    .. math::\n         \\mathbf{w} = g \\dfrac{\\mathbf{v}}{\\|\\mathbf{v}\\|}\n\n    By default, with `dim=0`, the norm is computed independently per output\n    channel/plane. 
To compute a norm over the entire weight tensor, use\n    `dim=None`.\n    \"\"\"\n    def compute_weight(self, module=None, name=None):\n        \"\"\"\n        Computes the weight-normalized weight value to assign to the module\n        attribute named `name`.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n        Returns:\n            w (Tensor): Tensor object containing value of reparameterized weight\n        \"\"\"\n        if module is None:\n            module = self.module\n        if name is None:\n            name = self.name\n        module, name = Reparameterization.get_module_and_name(module, name)\n        g = getattr(module, name + '_g')\n        v = getattr(module, name + '_v')\n\n        fused_weight_norm = Fused_Weight_Norm.apply\n        v = v.contiguous()\n        w = fused_weight_norm(v, g, self.dim)\n\n        return w\n\n    def reparameterize(self, name, weight, dim):\n        \"\"\"\n        Creates Parameters v and g to be used for weight normalization\n        and creates the names of the module attributes these Parameters\n        will correspond to. The parameters will be registered according to the names\n        provided.\n        Arguments:\n            module (nn.Module): module with weight we'd like to reparameterize\n            name (str, optional): name of weight parameter\n            dim (int, optional): dimension over which to compute parameterization\n        Returns:\n            names (list, str): names of Parameters to be used for reparameterization\n            params (list, Parameter): Parameters to be used for reparameterization\n        \"\"\"\n        names = [name + '_g', name + '_v']\n        params = [Parameter(_norm(weight, dim).data), Parameter(weight.data)]\n        return names, params\n"
  },
  {
    "path": "KoSimCSE/data/dataloader.py",
"content": "import torch\nimport logging\nfrom torch.utils.data import DataLoader, Dataset\nfrom transformers import AutoModel, AutoTokenizer\n\nlogger = logging.getLogger(__name__)\n\n\nclass ModelDataLoader(Dataset):\n    def __init__(self, file_path, args, metric, tokenizer, type_):\n        self.type = type_\n        self.args = args\n        self.metric = metric\n\n        \"\"\"NLI\"\"\"\n        self.anchor = []\n        self.positive = []\n        self.negative = []\n\n        \"\"\"STS\"\"\"\n        self.label = []\n        self.sentence_1 = []\n        self.sentence_2 = []\n\n        #  -------------------------------------\n        self.bert_tokenizer = tokenizer\n        self.file_path = file_path\n\n        \"\"\"\n        [CLS]: 2\n        [PAD]: 0\n        [UNK]: 1\n        \"\"\"\n        self.init_token = self.bert_tokenizer.cls_token\n        self.pad_token = self.bert_tokenizer.pad_token\n        self.unk_token = self.bert_tokenizer.unk_token\n\n        self.init_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.init_token)\n        self.pad_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.pad_token)\n        self.unk_token_idx = self.bert_tokenizer.convert_tokens_to_ids(self.unk_token)\n        \n    def load_data(self, type):\n\n        with open(self.file_path) as file:\n            lines = file.readlines()\n\n            for line in lines:\n                self.data2tensor(line, type)\n\n        if type == 'train':\n            assert len(self.anchor) == len(self.positive) == len(self.negative)\n        else:\n            assert len(self.sentence_1) == len(self.sentence_2) == len(self.label)\n\n    def data2tensor(self, line, type):\n        split_data = line.split('\\t')\n\n        if type == 'train':\n            anchor_sen, positive_sen, negative_sen = split_data\n\n            anchor = self.bert_tokenizer(anchor_sen, \n                                         truncation=True,\n                                         return_tensors=\"pt\",\n                                         max_length=self.args.max_len,\n                                         padding=\"max_length\")\n            \n            positive = self.bert_tokenizer(positive_sen, \n                                           truncation=True,\n                                           return_tensors=\"pt\",\n                                           max_length=self.args.max_len,\n                                           padding=\"max_length\")\n\n            negative = self.bert_tokenizer(negative_sen, \n                                           truncation=True,\n                                           return_tensors=\"pt\",\n                                           max_length=self.args.max_len,\n                                           padding=\"max_length\")\n            \n            self.anchor.append(anchor)\n            self.positive.append(positive)\n            self.negative.append(negative)\n\n        else:\n            sentence_1, sentence_2, label = split_data\n\n            sentence_1 = self.bert_tokenizer(sentence_1, \n                                             truncation=True,\n                                             return_tensors=\"pt\",\n                                             max_length=self.args.max_len,\n                                             padding=\"max_length\")\n\n            sentence_2 = self.bert_tokenizer(sentence_2,\n                                             truncation=True,\n                                             return_tensors=\"pt\",\n                                             max_length=self.args.max_len,\n                                             padding=\"max_length\")\n\n\n            self.sentence_1.append(sentence_1)\n            self.sentence_2.append(sentence_2)\n            self.label.append(float(label.strip())/5.0)\n\n    def __getitem__(self, index):\n\n        if self.type == 'train':\n 
           inputs = {'anchor': {\n                'source': torch.LongTensor(self.anchor[index]['input_ids']),\n                'attention_mask': self.anchor[index]['attention_mask'],\n                'token_type_ids': torch.LongTensor(self.anchor[index]['token_type_ids'])\n                },\n                      'positive': {\n                'source': torch.LongTensor(self.positive[index]['input_ids']),\n                'attention_mask': self.positive[index]['attention_mask'],\n                'token_type_ids': torch.LongTensor(self.positive[index]['token_type_ids'])\n                },\n                      'negative': {\n                'source': torch.LongTensor(self.negative[index]['input_ids']),\n                'attention_mask': self.negative[index]['attention_mask'],\n                'token_type_ids': torch.LongTensor(self.negative[index]['token_type_ids'])\n                }}\n        else:\n\n            inputs = {'sentence_1': {\n                'source': torch.LongTensor(self.sentence_1[index]['input_ids']),\n                'attention_mask': self.sentence_1[index]['attention_mask'],\n                'token_type_ids': torch.LongTensor(self.sentence_1[index]['token_type_ids'])\n                },\n                      'sentence_2': {\n                'source': torch.LongTensor(self.sentence_2[index]['input_ids']),\n                'attention_mask': self.sentence_2[index]['attention_mask'],\n                'token_type_ids': torch.LongTensor(self.sentence_2[index]['token_type_ids'])\n                },\n                      'label': {\n                          'value': torch.FloatTensor([self.label[index]])}\n                }\n\n        for key, value in inputs.items():\n            for inner_key, inner_value in value.items():\n                inputs[key][inner_key] = inner_value.squeeze(0)\n                \n        inputs = self.metric.move2device(inputs, self.args.device)\n        \n        return inputs\n\n    def __len__(self):\n        if 
self.type == 'train':\n            return len(self.anchor)\n        else:\n            return len(self.label)\n\n\n# Get train, valid, test data loader and BERT tokenizer\ndef get_loader(args, metric):\n    \n    tokenizer = AutoTokenizer.from_pretrained(args.model)\n    \n    path_to_train_data = args.path_to_data + '/' + args.train_data\n    path_to_valid_data = args.path_to_data + '/' + args.valid_data\n    path_to_test_data = args.path_to_data + '/' + args.test_data\n\n    if args.train == 'True' and args.test == 'False':\n        train_iter = ModelDataLoader(path_to_train_data, args, metric, tokenizer, type_='train')\n        valid_iter = ModelDataLoader(path_to_valid_data, args, metric, tokenizer, type_='valid')\n\n        train_iter.load_data('train')\n        valid_iter.load_data('valid')\n\n        loader = {'train': DataLoader(dataset=train_iter,\n                                      batch_size=args.batch_size,\n                                      shuffle=True),\n                  'valid': DataLoader(dataset=valid_iter,\n                                      batch_size=args.batch_size,\n                                      shuffle=True)}\n\n    elif args.train == 'False' and args.test == 'True':\n        test_iter = ModelDataLoader(path_to_test_data, args, metric, tokenizer, type_='test')\n        test_iter.load_data('test')\n\n        loader = {'test': DataLoader(dataset=test_iter,\n                                     batch_size=args.batch_size,\n                                     shuffle=True)}\n\n    else:\n        loader = None\n\n    return loader, tokenizer\n\n\ndef convert_to_tensor(corpus, tokenizer, device):\n    inputs = tokenizer(corpus,\n                       truncation=True,\n                       return_tensors=\"pt\",\n                       max_length=50,\n                       padding=\"max_length\")\n    \n    embedding = inputs['input_ids']\n    attention_mask = inputs['attention_mask']\n    token_type_ids = inputs['token_type_ids']\n        \n    inputs = {'source': torch.LongTensor(embedding).to(device),\n              'token_type_ids': torch.LongTensor(token_type_ids).to(device),\n              'attention_mask': attention_mask.to(device)}\n    \n    return inputs\n\n\ndef example_model_setting(model_ckpt, model_name):\n\n    from model.simcse.bert import BERT\n\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n\n    model = BERT(AutoModel.from_pretrained(model_name))\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    model.load_state_dict(torch.load(model_ckpt)['model'])\n    model.to(device)\n    model.eval()\n    \n    return model, tokenizer, device\n"
  },
  {
    "path": "KoSimCSE/main.py",
    "content": "from model.setting import Setting, Arguments\nfrom model.simcse.processor import Processor\n\n\ndef main(args, logger) -> None:\n    processor = Processor(args)\n    config = processor.model_setting()\n    logger.info('Model Setting Complete')\n\n    if args.train == 'True':\n        logger.info('Start Training')\n\n        for epoch in range(args.epochs):\n\n            processor.train(epoch+1)\n\n    if args.test == 'True':\n        logger.info(\"Start Test\")\n\n        processor.test()\n\n        processor.metric.print_size_of_model(config['model'])\n        processor.metric.count_parameters(config['model'])\n\n\nif __name__ == '__main__':\n    args, logger = Setting().run()\n    main(args, logger)\n"
  },
  {
    "path": "KoSimCSE/model/loss.py",
    "content": "import torch\nimport logging\nimport numpy as np\nimport torch.nn as nn\nfrom model.utils import Metric\nfrom scipy.stats import pearsonr, spearmanr\nfrom sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances\n\nlogger = logging.getLogger(__name__)\n\n\nclass Loss():\n\n    def __init__(self, args):\n        self.args = args\n        self.cos = nn.CosineSimilarity(dim=-1)\n        self.metric = Metric(args)\n\n    def train_loss_fct(self, config, inputs, a, p, n):\n         \n        positive_similarity = self.cos(a.unsqueeze(1), p.unsqueeze(0)) / self.args.temperature\n        negative_similarity = self.cos(a.unsqueeze(1), n.unsqueeze(0)) / self.args.temperature\n        cosine_similarity = torch.cat([positive_similarity, negative_similarity], dim=1).to(self.args.device)\n\n        labels = torch.arange(cosine_similarity.size(0)).long().to(self.args.device)\n\n        loss = config['criterion'](cosine_similarity, labels)\n\n        return loss\n\n    def evaluation_during_training(self, embeddings1, embeddings2, labels, indicator):\n\n        embeddings1 = embeddings1.cpu().numpy()\n        embeddings2 = embeddings2.cpu().numpy()\n        labels = labels['value'].cpu().numpy().flatten()\n\n        cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))\n        manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)\n        euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)\n        dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]\n\n        eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)\n        eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)\n\n        eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)\n        eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)\n\n        eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)\n     
   eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)\n\n        eval_pearson_dot, _ = pearsonr(labels, dot_products)\n        eval_spearman_dot, _ = spearmanr(labels, dot_products)\n\n        score = {'eval_pearson_cosine': eval_pearson_cosine,\n                 'eval_spearman_cosine': eval_spearman_cosine,\n                 'eval_pearson_manhattan': eval_pearson_manhattan,\n                 'eval_spearman_manhattan': eval_spearman_manhattan,\n                 'eval_pearson_euclidean': eval_pearson_euclidean,\n                 'eval_spearman_euclidean': eval_spearman_euclidean,\n                 'eval_pearson_dot': eval_pearson_dot,\n                 'eval_spearman_dot': eval_spearman_dot}\n\n        self.metric.update_indicator(indicator, score)\n\n        return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot)\n"
  },
  {
    "path": "KoSimCSE/model/setting.py",
"content": "import torch\nimport random\nimport logging\nimport numpy as np\nfrom argparse import ArgumentParser\n\n\nclass Arguments():\n\n    def __init__(self):\n        self.parser = ArgumentParser()\n\n    def add_type_of_processing(self):\n        self.add_argument('--opt_level', type=str, default='O1')\n        self.add_argument('--fp16', type=str, default='True')\n        self.add_argument('--train', type=str, default='True')\n        self.add_argument('--test', type=str, default='True')\n        self.add_argument('--device', type=str, default=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))\n\n    def add_hyper_parameters(self):\n        self.add_argument('--model', type=str, default='klue/bert-base')\n        self.add_argument('--patient', type=int, default=10)\n        self.add_argument('--dropout', type=float, default=0.1)\n        self.add_argument('--max_len', type=int, default=50)\n        self.add_argument('--batch_size', type=int, default=256)\n        self.add_argument('--epochs', type=int, default=3)\n        self.add_argument('--eval_steps', type=int, default=250)\n        self.add_argument('--seed', type=int, default=12)\n        self.add_argument('--lr', type=float, default=0.00005)\n        self.add_argument('--weight_decay', type=float, default=0.1)\n        self.add_argument('--warmup_ratio', type=float, default=0.05)\n        self.add_argument('--temperature', type=float, default=0.05)\n\n    def add_data_parameters(self):\n        self.add_argument('--train_data', type=str, default='train_nli.tsv')\n        self.add_argument('--valid_data', type=str, default='valid_sts.tsv')\n        self.add_argument('--test_data', type=str, default='test_sts.tsv')\n        self.add_argument('--task', type=str, default='NLU')\n        self.add_argument('--path_to_data', type=str, default='./data/')\n        self.add_argument('--path_to_save', type=str, default='./output/')\n        self.add_argument('--path_to_saved_model', 
type=str, default='./output/')\n        self.add_argument('--ckpt', type=str, default='best_checkpoint.pt')\n\n    def print_args(self, args):\n        for idx, (key, value) in enumerate(args.__dict__.items()):\n            if idx == 0:print(\"argparse{\\n\", \"\\t\", key, \":\", value)\n            elif idx == len(args.__dict__) - 1:print(\"\\t\", key, \":\", value, \"\\n}\")\n            else:print(\"\\t\", key, \":\", value)\n\n    def add_argument(self, *args, **kw_args):\n        return self.parser.add_argument(*args, **kw_args)\n\n    def parse(self):\n        args = self.parser.parse_args()\n        self.print_args(args)\n\n        return args\n\n\nclass Setting():\n\n    def set_logger(self):\n\n        _logger = logging.getLogger()\n        formatter = logging.Formatter(\n            '[%(levelname)s] %(asctime)s [ %(message)s ] | file::%(filename)s | line::%(lineno)s')\n\n        stream_handler = logging.StreamHandler()\n        stream_handler.setFormatter(formatter)\n\n        _logger.addHandler(stream_handler)\n        _logger.setLevel(logging.DEBUG)\n\n        return _logger\n\n    def set_seed(self, args):\n\n        seed = args.seed\n\n        random.seed(seed)\n        np.random.seed(seed)\n\n        torch.manual_seed(seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def run(self):\n\n        parser = Arguments()\n        parser.add_type_of_processing()\n        parser.add_hyper_parameters()\n        parser.add_data_parameters()\n\n        args = parser.parse()\n        logger = self.set_logger()\n        self.set_seed(args)\n\n        return args, logger\n"
  },
  {
    "path": "KoSimCSE/model/simcse/bert.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass BERT(nn.Module):\n    def __init__(self, bert):\n        super(BERT, self).__init__()\n        self.bert = bert\n\n    def forward(self, config, inputs, mode):\n\n        if mode == 'train':\n            \n            anchor_pooler, _ = self.bert(input_ids=inputs['anchor']['source'],\n                                         token_type_ids=inputs['anchor']['token_type_ids'],\n                                         attention_mask=inputs['anchor']['attention_mask'],\n                                         return_dict=False)\n            \n            positive_pooler, _ = self.bert(input_ids=inputs['positive']['source'],\n                                           token_type_ids=inputs['positive']['token_type_ids'],\n                                           attention_mask=inputs['positive']['attention_mask'],\n                                           return_dict=False)\n\n            negative_pooler, _ = self.bert(input_ids=inputs['negative']['source'],\n                                           token_type_ids=inputs['negative']['token_type_ids'],\n                                           attention_mask=inputs['negative']['attention_mask'],\n                                           return_dict=False)\n            \n            return anchor_pooler[:, 0], positive_pooler[:, 0], negative_pooler[:, 0]\n\n        else:\n            sentence_1_pooler, _ = self.bert(input_ids=inputs['sentence_1']['source'],\n                                             token_type_ids=inputs['sentence_1']['token_type_ids'],\n                                             attention_mask=inputs['sentence_1']['attention_mask'],\n                                             return_dict=False)\n            \n            sentence_2_pooler, _ = self.bert(input_ids=inputs['sentence_2']['source'],\n                                             token_type_ids=inputs['sentence_2']['token_type_ids'],\n                                
             attention_mask=inputs['sentence_2']['attention_mask'],\n                                             return_dict=False)\n        \n            return sentence_1_pooler[:, 0], sentence_2_pooler[:, 0]\n\n\n    def encode(self, inputs, device):\n    \n        embeddings, _ = self.bert(input_ids=inputs['source'].to(device),\n                                  token_type_ids=inputs['token_type_ids'].to(device),\n                                  attention_mask=inputs['attention_mask'].to(device),\n                                  return_dict=False)\n\n        return embeddings[:, 0]\n"
  },
  {
    "path": "KoSimCSE/model/simcse/processor.py",
    "content": "import os\nimport logging\nfrom apex import amp\nimport torch.nn as nn\nfrom tqdm import tqdm\nimport torch.quantization\nimport torch.optim as optim\nfrom model.loss import Loss\nfrom model.utils import Metric\nfrom transformers import AutoModel\nfrom model.simcse.bert import BERT\nfrom data.dataloader import get_loader\nfrom transformers import get_linear_schedule_with_warmup\n\nlogger = logging.getLogger(__name__)\n\n\nclass Processor():\n\n    def __init__(self, args):\n        self.args = args\n        self.config = None\n        self.metric = Metric(args)\n        self.loss = Loss(args)\n        self.total_steps = 0\n        self.model_checker = {'early_stop': False,\n                              'early_stop_patient': 0,\n                              'best_valid_score': 0}\n        self.dev_progress = {'score': 0, 'iter': 0}\n        self.model_progress = {'loss': 0, 'iter': 0}\n\n    def run(self, inputs, indicator=None, type=None):\n\n        if type == 'train':\n            anchor_embeddings, positive_embeddings, negative_embeddings = self.config['model'](self.config, inputs, type)\n            loss = self.loss.train_loss_fct(self.config,\n                                            inputs, \n                                            anchor_embeddings, \n                                            positive_embeddings, \n                                            negative_embeddings)\n            return loss\n        else:\n            sentence_1_embeddings, sentence_2_embeddings = self.config['model'](self.config, inputs, type)\n            score = self.loss.evaluation_during_training(sentence_1_embeddings,\n                                                         sentence_2_embeddings,\n                                                         inputs['label'],\n                                                         indicator)\n            return score\n\n    def progress(self, loss):\n        self.model_progress['loss'] += loss\n     
   self.model_progress['iter'] += 1\n\n    def progress_validation(self, score):\n        self.dev_progress['score'] += score\n        self.dev_progress['iter'] += 1\n\n    def return_value(self):\n        loss = self.model_progress['loss'].data.cpu().numpy() / self.model_progress['iter']\n        acc = self.model_progress['acc'].data.cpu().numpy() / self.model_progress['iter']\n\n        return loss, acc\n\n    def get_object(self, tokenizer, model):\n\n        no_decay = ['bias', 'LayerNorm.weight']\n        optimizer_grouped_parameters = [\n            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n             'weight_decay': self.args.weight_decay},\n            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n             'weight_decay': 0.0}\n        ]\n\n        criterion = nn.CrossEntropyLoss()\n        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)\n\n        return criterion, optimizer\n\n    def get_scheduler(self, optim, train_loader):\n        train_total = len(train_loader) * self.args.epochs\n        scheduler = get_linear_schedule_with_warmup(optim,\n                                                    num_warmup_steps=self.args.warmup_ratio * train_total,\n                                                    num_training_steps=train_total)\n\n        return scheduler, train_total\n\n    def model_setting(self):\n        loader, tokenizer = get_loader(self.args, self.metric)\n        model = BERT(AutoModel.from_pretrained(self.args.model))\n        model.to(self.args.device)\n\n        criterion, optimizer = self.get_object(tokenizer, model)\n\n        if self.args.train == 'True':\n            scheduler, total_steps = self.get_scheduler(optimizer, loader['train'])\n            self.total_steps = total_steps\n        else:\n            scheduler = None\n\n        config = {'loader': loader,\n                  'optimizer': optimizer,\n         
         'criterion': criterion,\n                  'scheduler': scheduler,\n                  'tokenizer': tokenizer,\n                  'args': self.args,\n                  'model': model}\n\n        if config['args'].fp16 == 'True':\n            config['model'], config['optimizer'] = amp.initialize(\n                config['model'], config['optimizer'], opt_level=config['args'].opt_level)\n\n        self.config = config\n\n        return self.config\n\n    def train(self, epoch):\n        self.config['model'].train()\n\n        for step, batch in enumerate(tqdm(self.config['loader']['train'])):\n            self.config['optimizer'].zero_grad()\n\n            inputs = batch\n\n            train_loss = self.run(inputs, type='train')\n\n            if self.args.fp16 == 'True':\n                with amp.scale_loss(train_loss, self.config['optimizer']) as scaled_loss:\n                    scaled_loss.backward()\n            else:\n                train_loss.backward()\n\n            self.config['optimizer'].step()\n            self.config['scheduler'].step()\n\n            self.progress(train_loss.data)\n\n            if self.model_progress['iter'] % self.args.eval_steps == 0 or self.model_progress['iter'] == self.total_steps:\n                valid_score = self.valid()\n                performance = {'tl': train_loss, 'vs': valid_score, 'ep': epoch, 'step': self.model_progress['iter']}\n                \n                self.metric.save_model(self.config, performance, self.model_checker)\n                self.config['model'].train()\n                \n    def valid(self):\n        self.config['model'].eval()\n        self.dev_progress = self.dev_progress.fromkeys(self.dev_progress, 0)\n\n        score_indicator = {'eval_pearson_cosine': 0,\n                           'eval_spearman_cosine': 0,\n                           'eval_pearson_manhattan': 0,\n                           'eval_spearman_manhattan': 0,\n                           'eval_pearson_euclidean': 0,\n  
                         'eval_spearman_euclidean': 0,\n                           'eval_pearson_dot': 0,\n                           'eval_spearman_dot': 0}\n\n        with torch.no_grad():\n            for step, batch in enumerate(self.config['loader']['valid']):\n                inputs = batch\n                score = self.run(inputs, indicator=score_indicator, type='valid')\n\n                self.progress_validation(score)\n\n        score = self.metric.cal_dev_score(self.dev_progress, score_indicator)\n\n        return score\n\n    def test(self):\n        self.config['model'].load_state_dict(torch.load(self.args.path_to_saved_model)['model'], strict=False)\n        self.config['model'].eval()\n        self.dev_progress = self.dev_progress.fromkeys(self.dev_progress, 0)\n\n        score_indicator = {'eval_pearson_cosine': 0,\n                           'eval_spearman_cosine': 0,\n                           'eval_pearson_manhattan': 0,\n                           'eval_spearman_manhattan': 0,\n                           'eval_pearson_euclidean': 0,\n                           'eval_spearman_euclidean': 0,\n                           'eval_pearson_dot': 0,\n                           'eval_spearman_dot': 0}\n\n        with torch.no_grad():\n            for step, batch in enumerate(self.config['loader']['test']):\n                inputs = batch\n                score = self.run(inputs, indicator=score_indicator, type='test')\n\n                self.progress_validation(score)\n\n        logger.info('### TEST SCORE ###')\n        score = self.metric.cal_dev_score(self.dev_progress, score_indicator)\n"
  },
  {
    "path": "KoSimCSE/model/utils.py",
    "content": "import os\nimport torch\nimport logging\nfrom tensorboardX import SummaryWriter\n\nlogger = logging.getLogger(__name__)\nwriter = SummaryWriter()\n\n\nclass Metric():\n\n    def __init__(self, args):\n        self.args = args\n\n    def get_lr(self, optimizer):\n        return optimizer.state_dict()['param_groups'][0]['lr']\n\n    def count_parameters(self, model):\n        print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n\n    def cal_acc(self, yhat, y):\n        with torch.no_grad():\n            yhat = yhat.max(dim=-1)[1]\n            acc = (yhat == y).float().mean()\n\n        return acc\n\n    def cal_time(self, start_time, end_time):\n        elapsed_time = end_time - start_time\n        elapsed_mins = int(elapsed_time / 60)\n        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n\n        return elapsed_mins, elapsed_secs\n\n    def cal_dev_score(self, score, indicator):\n        validation_score = score['score'] / score['iter']\n        for key, value in indicator.items():\n            indicator[key] /= score['iter']\n\n        print(\"\\n\\nCosine-Similarity :\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_cosine'], indicator['eval_spearman_cosine']))\n        print(\"Manhattan-Distance:\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_manhattan'], indicator['eval_spearman_manhattan']))\n        print(\"Euclidean-Distance:\\tPearson: {:.4f}\\tSpearman: {:.4f}\".format(\n            indicator['eval_pearson_euclidean'], indicator['eval_spearman_euclidean']))\n        print(\"Dot-Product-Similarity:\\tPearson: {:.4f}\\tSpearman: {:.4f}\\n\".format(\n            indicator['eval_pearson_dot'], indicator['eval_spearman_dot']))\n\n        return validation_score\n\n    def update_indicator(self, indicator, score):\n        for key, value in indicator.items():\n            if key == 'eval_spearman_cosine':\n                indicator[key] += 
score['eval_spearman_cosine']\n            elif key == 'eval_pearson_cosine':\n                indicator[key] += score['eval_pearson_cosine']\n            elif key == 'eval_spearman_manhattan':\n                indicator[key] += score['eval_spearman_manhattan']\n            elif key == 'eval_pearson_manhattan':\n                indicator[key] += score['eval_pearson_manhattan']\n            elif key == 'eval_spearman_euclidean':\n                indicator[key] += score['eval_spearman_euclidean']\n            elif key == 'eval_pearson_euclidean':\n                indicator[key] += score['eval_pearson_euclidean']\n            elif key == 'eval_spearman_dot':\n                indicator[key] += score['eval_spearman_dot']\n            elif key == 'eval_pearson_dot':\n                indicator[key] += score['eval_pearson_dot']\n\n    def draw_graph(self, cp):\n        writer.add_scalars('loss_graph', {'train': cp['tl'], 'valid': cp['vl']}, cp['ep'])\n        writer.add_scalars('acc_graph', {'train': cp['tma'], 'valid': cp['vma']}, cp['ep'])\n\n    def performance_check(self, cp, config):\n        print(f'\\t==Epoch: {cp[\"ep\"] + 1:02} | Epoch Time: {cp[\"epm\"]}m {cp[\"eps\"]}s==')\n        print(f'\\t==Train Loss: {cp[\"tl\"]:.4f} | Train acc: {cp[\"tma\"]:.4f}==')\n        print(f'\\t==Valid Loss: {cp[\"vl\"]:.4f} | Valid acc: {cp[\"vma\"]:.4f}==')\n        print(f'\\t==Epoch latest LR: {self.get_lr(config[\"optimizer\"]):.9f}==\\n')\n\n    def print_size_of_model(self, model):\n        torch.save(model.state_dict(), \"temp.p\")\n        print('Size (MB):', os.path.getsize(\"temp.p\") / 1e6)\n        os.remove('temp.p')\n\n    def move2device(self, sample, device):\n        if len(sample) == 0:\n            return {}\n\n        def _move_to_device(maybe_tensor, device):\n            if torch.is_tensor(maybe_tensor):\n                return maybe_tensor.to(device)\n            elif isinstance(maybe_tensor, dict):\n                return {\n                    key: 
_move_to_device(value, device)\n                    for key, value in maybe_tensor.items()\n                    }\n            elif isinstance(maybe_tensor, list):\n                return [_move_to_device(x, device) for x in maybe_tensor]\n            elif isinstance(maybe_tensor, tuple):\n                # preserve the tuple type instead of silently converting it to a list\n                return tuple(_move_to_device(x, device) for x in maybe_tensor)\n            else:\n                return maybe_tensor\n\n        return _move_to_device(sample, device)\n\n    def save_model(self, config, cp, pco):\n        if not os.path.exists(config['args'].path_to_save):\n            os.makedirs(config['args'].path_to_save)\n\n        # use os.path.join so a missing trailing slash in path_to_save cannot corrupt the filename\n        sorted_path = os.path.join(config['args'].path_to_save, \"kosimcse-\" + config['args'].model.replace(\"/\", \"-\") + '.pt')\n        if cp['vs'] > pco['best_valid_score']:\n            pco['best_valid_score'] = cp['vs']\n\n            state = {'model': config['model'].state_dict(),\n                     'optimizer': config['optimizer'].state_dict()}\n\n            torch.save(state, sorted_path)\n            print(f'\\t## SAVE {sorted_path} |'\n                  f' valid_score: {cp[\"vs\"]:.4f} |'\n                  f' epochs: {cp[\"ep\"]} |'\n                  f' steps: {cp[\"step\"]} ##\\n')\n\ndef pytorch_cos_sim(a, b):\n    \"\"\"\n    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.\n    This function can be used as a faster replacement for 1-scipy.spatial.distance.cdist(a,b)\n    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])\n    \"\"\"\n    if not isinstance(a, torch.Tensor):\n        a = torch.tensor(a)\n\n    if not isinstance(b, torch.Tensor):\n        b = torch.tensor(b)\n\n    if len(a.shape) == 1:\n        a = a.unsqueeze(0)\n\n    if len(b.shape) == 1:\n        b = b.unsqueeze(0)\n\n    a_norm = a / a.norm(dim=1)[:, None]\n    b_norm = b / b.norm(dim=1)[:, None]\n    return torch.mm(a_norm, b_norm.transpose(0, 1))\n"
  },
  {
    "path": "KoSimCSE/output/empty.txt",
    "content": "##"
  },
  {
    "path": "KoSimCSE/requirements.txt",
    "content": "torch >= 1.7.0\nmxnet >= 1.4.0\ngluonnlp >= 0.6.0\nsentencepiece >= 0.1.6\nonnxruntime >= 0.3.0\ntransformers == 2.8.0\n"
  },
  {
    "path": "KoSimCSE/run_example.sh",
"content": "#!/bin/bash\n\necho \"Start Training (BERT-BASE)\"\n\nCUDA_VISIBLE_DEVICES=0 python main.py \\\n  --model klue/bert-base \\\n  --test False \\\n  --max_len 50 \\\n  --batch_size 512 \\\n  --epochs 2 \\\n  --eval_steps 250 \\\n  --lr 0.0001 \\\n  --warmup_ratio 0.1 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --train_data train_nli.tsv \\\n  --valid_data valid_sts.tsv\n\necho \"Start Testing (BERT-BASE)\"\n\nCUDA_VISIBLE_DEVICES=0 python main.py \\\n  --model klue/bert-base \\\n  --train False \\\n  --test True \\\n  --max_len 50 \\\n  --batch_size 512 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --test_data test_sts.tsv \\\n  --path_to_saved_model output/kosimcse-klue-bert-base.pt\n\necho \"Start Training (RoBERTa-BASE)\"\n\nCUDA_VISIBLE_DEVICES=0 python main.py \\\n  --model klue/roberta-base \\\n  --test False \\\n  --max_len 50 \\\n  --batch_size 512 \\\n  --epochs 2 \\\n  --eval_steps 125 \\\n  --lr 0.0001 \\\n  --warmup_ratio 0.2 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --train_data train_nli.tsv \\\n  --valid_data valid_sts.tsv\n\necho \"Start Testing (RoBERTa-BASE)\"\n\nCUDA_VISIBLE_DEVICES=0 python main.py \\\n  --model klue/roberta-base \\\n  --train False \\\n  --test True \\\n  --max_len 50 \\\n  --batch_size 512 \\\n  --temperature 0.05 \\\n  --path_to_data ../Dataset/ \\\n  --test_data test_sts.tsv \\\n  --path_to_saved_model output/kosimcse-klue-roberta-base.pt\n"
  },
  {
    "path": "LICENSE",
    "content": "Attribution-ShareAlike 4.0 International\n\n=======================================================================\n\nCreative Commons Corporation (\"Creative Commons\") is not a law firm and\ndoes not provide legal services or legal advice. Distribution of\nCreative Commons public licenses does not create a lawyer-client or\nother relationship. Creative Commons makes its licenses and related\ninformation available on an \"as-is\" basis. Creative Commons gives no\nwarranties regarding its licenses, any material licensed under their\nterms and conditions, or any related information. Creative Commons\ndisclaims all liability for damages resulting from their use to the\nfullest extent possible.\n\nUsing Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and\nconditions that creators and other rights holders may use to share\noriginal works of authorship and other material subject to copyright\nand certain other rights specified in the public license below. The\nfollowing considerations are for informational purposes only, are not\nexhaustive, and do not form part of our licenses.\n\n     Considerations for licensors: Our public licenses are\n     intended for use by those authorized to give the public\n     permission to use material in ways otherwise restricted by\n     copyright and certain other rights. Our licenses are\n     irrevocable. Licensors should read and understand the terms\n     and conditions of the license they choose before applying it.\n     Licensors should also secure all rights necessary before\n     applying our licenses so that the public can reuse the\n     material as expected. Licensors should clearly mark any\n     material not subject to the license. This includes other CC-\n     licensed material, or material used under an exception or\n     limitation to copyright. 
More considerations for licensors:\n     wiki.creativecommons.org/Considerations_for_licensors\n\n     Considerations for the public: By using one of our public\n     licenses, a licensor grants the public permission to use the\n     licensed material under specified terms and conditions. If\n     the licensor's permission is not necessary for any reason--for\n     example, because of any applicable exception or limitation to\n     copyright--then that use is not regulated by the license. Our\n     licenses grant only permissions under copyright and certain\n     other rights that a licensor has authority to grant. Use of\n     the licensed material may still be restricted for other\n     reasons, including because others have copyright or other\n     rights in the material. A licensor may make special requests,\n     such as asking that all changes be marked or described.\n     Although not required by our licenses, you are encouraged to\n     respect those requests where reasonable. More considerations\n     for the public:\n     wiki.creativecommons.org/Considerations_for_licensees\n\n=======================================================================\n\nCreative Commons Attribution-ShareAlike 4.0 International Public\nLicense\n\nBy exercising the Licensed Rights (defined below), You accept and agree\nto be bound by the terms and conditions of this Creative Commons\nAttribution-ShareAlike 4.0 International Public License (\"Public\nLicense\"). To the extent this Public License may be interpreted as a\ncontract, You are granted the Licensed Rights in consideration of Your\nacceptance of these terms and conditions, and the Licensor grants You\nsuch rights in consideration of benefits the Licensor receives from\nmaking the Licensed Material available under these terms and\nconditions.\n\n\nSection 1 -- Definitions.\n\n  a. 
Adapted Material means material subject to Copyright and Similar\n     Rights that is derived from or based upon the Licensed Material\n     and in which the Licensed Material is translated, altered,\n     arranged, transformed, or otherwise modified in a manner requiring\n     permission under the Copyright and Similar Rights held by the\n     Licensor. For purposes of this Public License, where the Licensed\n     Material is a musical work, performance, or sound recording,\n     Adapted Material is always produced where the Licensed Material is\n     synched in timed relation with a moving image.\n\n  b. Adapter's License means the license You apply to Your Copyright\n     and Similar Rights in Your contributions to Adapted Material in\n     accordance with the terms and conditions of this Public License.\n\n  c. BY-SA Compatible License means a license listed at\n     creativecommons.org/compatiblelicenses, approved by Creative\n     Commons as essentially the equivalent of this Public License.\n\n  d. Copyright and Similar Rights means copyright and/or similar rights\n     closely related to copyright including, without limitation,\n     performance, broadcast, sound recording, and Sui Generis Database\n     Rights, without regard to how the rights are labeled or\n     categorized. For purposes of this Public License, the rights\n     specified in Section 2(b)(1)-(2) are not Copyright and Similar\n     Rights.\n\n  e. Effective Technological Measures means those measures that, in the\n     absence of proper authority, may not be circumvented under laws\n     fulfilling obligations under Article 11 of the WIPO Copyright\n     Treaty adopted on December 20, 1996, and/or similar international\n     agreements.\n\n  f. Exceptions and Limitations means fair use, fair dealing, and/or\n     any other exception or limitation to Copyright and Similar Rights\n     that applies to Your use of the Licensed Material.\n\n  g. 
License Elements means the license attributes listed in the name\n     of a Creative Commons Public License. The License Elements of this\n     Public License are Attribution and ShareAlike.\n\n  h. Licensed Material means the artistic or literary work, database,\n     or other material to which the Licensor applied this Public\n     License.\n\n  i. Licensed Rights means the rights granted to You subject to the\n     terms and conditions of this Public License, which are limited to\n     all Copyright and Similar Rights that apply to Your use of the\n     Licensed Material and that the Licensor has authority to license.\n\n  j. Licensor means the individual(s) or entity(ies) granting rights\n     under this Public License.\n\n  k. Share means to provide material to the public by any means or\n     process that requires permission under the Licensed Rights, such\n     as reproduction, public display, public performance, distribution,\n     dissemination, communication, or importation, and to make material\n     available to the public including in ways that members of the\n     public may access the material from a place and at a time\n     individually chosen by them.\n\n  l. Sui Generis Database Rights means rights other than copyright\n     resulting from Directive 96/9/EC of the European Parliament and of\n     the Council of 11 March 1996 on the legal protection of databases,\n     as amended and/or succeeded, as well as other essentially\n     equivalent rights anywhere in the world.\n\n  m. You means the individual or entity exercising the Licensed Rights\n     under this Public License. Your has a corresponding meaning.\n\n\nSection 2 -- Scope.\n\n  a. License grant.\n\n       1. 
Subject to the terms and conditions of this Public License,\n          the Licensor hereby grants You a worldwide, royalty-free,\n          non-sublicensable, non-exclusive, irrevocable license to\n          exercise the Licensed Rights in the Licensed Material to:\n\n            a. reproduce and Share the Licensed Material, in whole or\n               in part; and\n\n            b. produce, reproduce, and Share Adapted Material.\n\n       2. Exceptions and Limitations. For the avoidance of doubt, where\n          Exceptions and Limitations apply to Your use, this Public\n          License does not apply, and You do not need to comply with\n          its terms and conditions.\n\n       3. Term. The term of this Public License is specified in Section\n          6(a).\n\n       4. Media and formats; technical modifications allowed. The\n          Licensor authorizes You to exercise the Licensed Rights in\n          all media and formats whether now known or hereafter created,\n          and to make technical modifications necessary to do so. The\n          Licensor waives and/or agrees not to assert any right or\n          authority to forbid You from making technical modifications\n          necessary to exercise the Licensed Rights, including\n          technical modifications necessary to circumvent Effective\n          Technological Measures. For purposes of this Public License,\n          simply making modifications authorized by this Section 2(a)\n          (4) never produces Adapted Material.\n\n       5. Downstream recipients.\n\n            a. Offer from the Licensor -- Licensed Material. Every\n               recipient of the Licensed Material automatically\n               receives an offer from the Licensor to exercise the\n               Licensed Rights under the terms and conditions of this\n               Public License.\n\n            b. 
Additional offer from the Licensor -- Adapted Material.\n               Every recipient of Adapted Material from You\n               automatically receives an offer from the Licensor to\n               exercise the Licensed Rights in the Adapted Material\n               under the conditions of the Adapter's License You apply.\n\n            c. No downstream restrictions. You may not offer or impose\n               any additional or different terms or conditions on, or\n               apply any Effective Technological Measures to, the\n               Licensed Material if doing so restricts exercise of the\n               Licensed Rights by any recipient of the Licensed\n               Material.\n\n       6. No endorsement. Nothing in this Public License constitutes or\n          may be construed as permission to assert or imply that You\n          are, or that Your use of the Licensed Material is, connected\n          with, or sponsored, endorsed, or granted official status by,\n          the Licensor or others designated to receive attribution as\n          provided in Section 3(a)(1)(A)(i).\n\n  b. Other rights.\n\n       1. Moral rights, such as the right of integrity, are not\n          licensed under this Public License, nor are publicity,\n          privacy, and/or other similar personality rights; however, to\n          the extent possible, the Licensor waives and/or agrees not to\n          assert any such rights held by the Licensor to the limited\n          extent necessary to allow You to exercise the Licensed\n          Rights, but not otherwise.\n\n       2. Patent and trademark rights are not licensed under this\n          Public License.\n\n       3. To the extent possible, the Licensor waives any right to\n          collect royalties from You for the exercise of the Licensed\n          Rights, whether directly or through a collecting society\n          under any voluntary or waivable statutory or compulsory\n          licensing scheme. 
In all other cases the Licensor expressly\n          reserves any right to collect such royalties.\n\n\nSection 3 -- License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the\nfollowing conditions.\n\n  a. Attribution.\n\n       1. If You Share the Licensed Material (including in modified\n          form), You must:\n\n            a. retain the following if it is supplied by the Licensor\n               with the Licensed Material:\n\n                 i. identification of the creator(s) of the Licensed\n                    Material and any others designated to receive\n                    attribution, in any reasonable manner requested by\n                    the Licensor (including by pseudonym if\n                    designated);\n\n                ii. a copyright notice;\n\n               iii. a notice that refers to this Public License;\n\n                iv. a notice that refers to the disclaimer of\n                    warranties;\n\n                 v. a URI or hyperlink to the Licensed Material to the\n                    extent reasonably practicable;\n\n            b. indicate if You modified the Licensed Material and\n               retain an indication of any previous modifications; and\n\n            c. indicate the Licensed Material is licensed under this\n               Public License, and include the text of, or the URI or\n               hyperlink to, this Public License.\n\n       2. You may satisfy the conditions in Section 3(a)(1) in any\n          reasonable manner based on the medium, means, and context in\n          which You Share the Licensed Material. For example, it may be\n          reasonable to satisfy the conditions by providing a URI or\n          hyperlink to a resource that includes the required\n          information.\n\n       3. If requested by the Licensor, You must remove any of the\n          information required by Section 3(a)(1)(A) to the extent\n          reasonably practicable.\n\n  b. 
ShareAlike.\n\n     In addition to the conditions in Section 3(a), if You Share\n     Adapted Material You produce, the following conditions also apply.\n\n       1. The Adapter's License You apply must be a Creative Commons\n          license with the same License Elements, this version or\n          later, or a BY-SA Compatible License.\n\n       2. You must include the text of, or the URI or hyperlink to, the\n          Adapter's License You apply. You may satisfy this condition\n          in any reasonable manner based on the medium, means, and\n          context in which You Share Adapted Material.\n\n       3. You may not offer or impose any additional or different terms\n          or conditions on, or apply any Effective Technological\n          Measures to, Adapted Material that restrict exercise of the\n          rights granted under the Adapter's License You apply.\n\n\nSection 4 -- Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that\napply to Your use of the Licensed Material:\n\n  a. for the avoidance of doubt, Section 2(a)(1) grants You the right\n     to extract, reuse, reproduce, and Share all or a substantial\n     portion of the contents of the database;\n\n  b. if You include all or a substantial portion of the database\n     contents in a database in which You have Sui Generis Database\n     Rights, then the database in which You have Sui Generis Database\n     Rights (but not its individual contents) is Adapted Material,\n     including for purposes of Section 3(b); and\n\n  c. You must comply with the conditions in Section 3(a) if You Share\n     all or a substantial portion of the contents of the database.\n\nFor the avoidance of doubt, this Section 4 supplements and does not\nreplace Your obligations under this Public License where the Licensed\nRights include other Copyright and Similar Rights.\n\n\nSection 5 -- Disclaimer of Warranties and Limitation of Liability.\n\n  a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE\n     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS\n     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF\n     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,\n     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,\n     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR\n     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,\n     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT\n     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT\n     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.\n\n  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE\n     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,\n     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,\n     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,\n     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR\n     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN\n     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR\n     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR\n     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.\n\n  c. The disclaimer of warranties and limitation of liability provided\n     above shall be interpreted in a manner that, to the extent\n     possible, most closely approximates an absolute disclaimer and\n     waiver of all liability.\n\n\nSection 6 -- Term and Termination.\n\n  a. This Public License applies for the term of the Copyright and\n     Similar Rights licensed here. However, if You fail to comply with\n     this Public License, then Your rights under this Public License\n     terminate automatically.\n\n  b. Where Your right to use the Licensed Material has terminated under\n     Section 6(a), it reinstates:\n\n       1. 
automatically as of the date the violation is cured, provided\n          it is cured within 30 days of Your discovery of the\n          violation; or\n\n       2. upon express reinstatement by the Licensor.\n\n     For the avoidance of doubt, this Section 6(b) does not affect any\n     right the Licensor may have to seek remedies for Your violations\n     of this Public License.\n\n  c. For the avoidance of doubt, the Licensor may also offer the\n     Licensed Material under separate terms or conditions or stop\n     distributing the Licensed Material at any time; however, doing so\n     will not terminate this Public License.\n\n  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public\n     License.\n\n\nSection 7 -- Other Terms and Conditions.\n\n  a. The Licensor shall not be bound by any additional or different\n     terms or conditions communicated by You unless expressly agreed.\n\n  b. Any arrangements, understandings, or agreements regarding the\n     Licensed Material not stated herein are separate from and\n     independent of the terms and conditions of this Public License.\n\n\nSection 8 -- Interpretation.\n\n  a. For the avoidance of doubt, this Public License does not, and\n     shall not be interpreted to, reduce, limit, restrict, or impose\n     conditions on any use of the Licensed Material that could lawfully\n     be made without permission under this Public License.\n\n  b. To the extent possible, if any provision of this Public License is\n     deemed unenforceable, it shall be automatically reformed to the\n     minimum extent necessary to make it enforceable. If the provision\n     cannot be reformed, it shall be severed from this Public License\n     without affecting the enforceability of the remaining terms and\n     conditions.\n\n  c. No term or condition of this Public License will be waived and no\n     failure to comply consented to unless expressly agreed to by the\n     Licensor.\n\n  d. 
Nothing in this Public License constitutes or may be interpreted\n     as a limitation upon, or waiver of, any privileges and immunities\n     that apply to the Licensor or You, including from the legal\n     processes of any jurisdiction or authority.\n\n\n=======================================================================\n\nCreative Commons is not a party to its public licenses.\nNotwithstanding, Creative Commons may elect to apply one of its public\nlicenses to material it publishes and in those instances will be\nconsidered the “Licensor.” The text of the Creative Commons public\nlicenses is dedicated to the public domain under the CC0 Public Domain\nDedication. Except for the limited purpose of indicating that material\nis shared under a Creative Commons public license or as otherwise\npermitted by the Creative Commons policies published at\ncreativecommons.org/policies, Creative Commons does not authorize the\nuse of the trademark \"Creative Commons\" or any other trademark or logo\nof Creative Commons without its prior written consent including,\nwithout limitation, in connection with any unauthorized modifications\nto any of its public licenses or any other arrangements,\nunderstandings, or agreements concerning use of licensed material. For\nthe avoidance of doubt, this paragraph does not form part of the public\nlicenses.\n\nCreative Commons may be contacted at creativecommons.org.\n"
  },
  {
    "path": "README.md",
    "content": "# Korean-Sentence-Embedding\nThe Korean Sentence Embedding Repository offers pre-trained models that can be easily downloaded and used immediately. Additionally, it provides an optimized environment for customized model training.\n\n## Quick tour\n> **Note** <br>\n> All the pretrained models are uploaded in Huggingface Model Hub. Check https://huggingface.co/BM-K\n```python\nimport torch\nfrom transformers import AutoModel, AutoTokenizer\n\ndef cal_score(a, b):\n    if len(a.shape) == 1: a = a.unsqueeze(0)\n    if len(b.shape) == 1: b = b.unsqueeze(0)\n\n    a_norm = a / a.norm(dim=1)[:, None]\n    b_norm = b / b.norm(dim=1)[:, None]\n    return torch.mm(a_norm, b_norm.transpose(0, 1)) * 100\n\nmodel = AutoModel.from_pretrained('BM-K/KoSimCSE-roberta-multitask')  # or 'BM-K/KoSimCSE-bert-multitask'\ntokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta-multitask')  # or 'BM-K/KoSimCSE-bert-multitask'\n\nsentences = ['치타가 들판을 가로 질러 먹이를 쫓는다.',\n             '치타 한 마리가 먹이 뒤에서 달리고 있다.',\n             '원숭이 한 마리가 드럼을 연주한다.']\n\ninputs = tokenizer(sentences, padding=True, truncation=True, return_tensors=\"pt\")\nembeddings, _ = model(**inputs, return_dict=False)\n\nscore01 = cal_score(embeddings[0][0], embeddings[1][0])  # 84.09\n# '치타가 들판을 가로 질러 먹이를 쫓는다.' @ '치타 한 마리가 먹이 뒤에서 달리고 있다.'\nscore02 = cal_score(embeddings[0][0], embeddings[2][0])  # 23.21\n# '치타가 들판을 가로 질러 먹이를 쫓는다.' 
@ '원숭이 한 마리가 드럼을 연주한다.'\n```\n\n## Update history\n\n**Updates on Mar.08.2023**\n- Update Unsupervised Models\n\n**Updates on Feb.24.2023**\n- Upload KoSimCSE clustering example\n\n**Updates on Nov.15.2022**\n- Upload KoDiffCSE-unsupervised training code\n\n**Updates on Oct.27.2022**\n- Upload KoDiffCSE-unsupervised performance\n\n**Updates on Oct.21.2022**\n- Upload KoSimCSE-unsupervised performance\n\n**Updates on Jun.01.2022**\n- Release KoSimCSE-multitask models\n\n**Updates on May.23.2022**\n- Upload KoSentenceT5 training code\n- Upload KoSentenceT5 performance\n\n**Updates on Mar.01.2022**\n- Release KoSimCSE\n\n**Updates on Feb.11.2022**\n- Upload KoSimCSE training code\n- Upload KoSimCSE performance\n\n**Updates on Jan.26.2022**\n- Upload KoSBERT training code\n- Upload KoSBERT performance\n\n## Baseline Models\nBaseline models used for Korean sentence embedding - [KLUE-PLMs](https://github.com/KLUE-benchmark/KLUE/blob/main/README.md)\n\n| Model                | Embedding size | Hidden size | # Layers | # Heads |\n|----------------------|----------------|-------------|----------|---------|\n| KLUE-BERT-base            | 768            | 768         | 12       | 12      |\n| KLUE-RoBERTa-base         | 768            | 768         | 12       | 12      |\n\n> **Warning** <br>\n> Large pre-trained models need a lot of GPU memory to train.\n\n## Available Models\n1. Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks [[SBERT]-[EMNLP 2019]](https://arxiv.org/abs/1908.10084)\n2. SimCSE: Simple Contrastive Learning of Sentence Embeddings [[SimCSE]-[EMNLP 2021]](https://arxiv.org/abs/2104.08821)\n3. Sentence-T5: Scalable Sentence Encoders from Pre-trained Text-to-Text Models [[Sentence-T5]-[ACL findings 2022]](https://arxiv.org/abs/2108.08877)\n4. 
DiffCSE: Difference-based Contrastive Learning for Sentence Embeddings [[DiffCSE]-[NAACL 2022]](https://arxiv.org/abs/2204.10298)\n\n## Datasets\n- [kakaobrain KorNLU Datasets](https://github.com/kakaobrain/KorNLUDatasets) (Supervised setting)\n- [wiki-corpus](https://github.com/jeongukjae/korean-wikipedia-corpus) (Unsupervised setting)\n\n## Setups\n[![Python](https://img.shields.io/badge/python-3.8.5-blue?logo=python&logoColor=FED643)](https://www.python.org/downloads/release/python-385/)\n[![Pytorch](https://img.shields.io/badge/pytorch-1.7.1-red?logo=pytorch)](https://pytorch.org/get-started/previous-versions/)\n\n### KoSentenceBERT\n- 🤗 [Model Training](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSBERT)\n- Dataset (Supervised)\n    - Training: snli_1.0_train.ko.tsv, sts-train.tsv (multi-task)\n      - Performance can be further improved by adding multinli data to training.\n    - Validation: sts-dev.tsv\n    - Test: sts-test.tsv\n\n### KoSimCSE\n- 🤗 [Model Training](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSimCSE)\n- Dataset (Supervised)\n    - Training: snli_1.0_train.ko.tsv + multinli.train.ko.tsv (Supervised setting)\n    - Validation: sts-dev.tsv\n    - Test: sts-test.tsv\n- Dataset (Unsupervised)\n    - Training: wiki_corpus.txt\n    - Validation: sts-dev.tsv\n    - Test: sts-test.tsv\n\n### KoSentenceT5\n- 🤗 [Model Training](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSentenceT5)\n- Dataset (Supervised)\n    - Training: snli_1.0_train.ko.tsv + multinli.train.ko.tsv\n    - Validation: sts-dev.tsv\n    - Test: sts-test.tsv\n\n### KoDiffCSE\n- 🤗 [Model Training](https://github.com/BM-K/KoDiffCSE)\n- Dataset (Unsupervised)\n    - Training: wiki_corpus.txt\n    - Validation: sts-dev.tsv\n    - Test: sts-test.tsv\n    \n## Performance-supervised\n\n| Model                  | Average | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson 
| Manhattan Spearman | Dot Pearson | Dot Spearman |\n|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|\n| KoSBERT<sup>†</sup><sub>SKT</sub>    | 77.40 | 78.81 | 78.47 | 77.68 | 77.78 | 77.71 | 77.83 | 75.75 | 75.22 |\n| KoSBERT              | 80.39 | 82.13 | 82.25 | 80.67 | 80.75 | 80.69 | 80.78 | 77.96 | 77.90 |\n| KoSRoBERTa           | 81.64 | 81.20 | 82.20 | 81.79 | 82.34 | 81.59 | 82.20 | 80.62 | 81.25 |\n| | | | | | | | | |\n| KoSentenceBART         | 77.14 | 79.71 | 78.74 | 78.42 | 78.02 | 78.40 | 78.00 | 74.24 | 72.15 |\n| KoSentenceT5          | 77.83 | 80.87 | 79.74 | 80.24 | 79.36 | 80.19 | 79.27 | 72.81 | 70.17 |\n| | | | | | | | | |\n| KoSimCSE-BERT<sup>†</sup><sub>SKT</sub>   | 81.32 | 82.12 | 82.56 | 81.84 | 81.63 | 81.99 | 81.74 | 79.55 | 79.19 |\n| KoSimCSE-BERT              | 83.37 | 83.22 | 83.58 | 83.24 | 83.60 | 83.15 | 83.54 | 83.13 | 83.49 |\n| KoSimCSE-RoBERTa          | 83.65 | 83.60 | 83.77 | 83.54 | 83.76 | 83.55 | 83.77 | 83.55 | 83.64 |\n| | | | | | | | | | |\n| KoSimCSE-BERT-multitask              | 85.71 | 85.29 | 86.02 | 85.63 | 86.01 | 85.57 | 85.97 | 85.26 | 85.93 |\n| KoSimCSE-RoBERTa-multitask          | 85.77 | 85.08 | 86.12 | 85.84 | 86.12 | 85.83 | 86.12 | 85.03 | 85.99 |\n\n- [KoSBERT<sup>†</sup><sub>SKT</sub>](https://github.com/BM-K/KoSentenceBERT-SKT)\n- [KoSimCSE-BERT<sup>†</sup><sub>SKT</sub>](https://github.com/BM-K/KoSimCSE-SKT)\n\n## Performance-unsupervised\n\n| Model                  | Average | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson | Manhattan Spearman | Dot Pearson | Dot Spearman |\n|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|\n| KoSRoBERTa-base<sup>†</sup>    | N/A | N/A | 48.96 | N/A | N/A | N/A | N/A | N/A | N/A |\n| KoSRoBERTa-large<sup>†</sup>    | N/A | N/A | 51.35 | N/A | N/A | N/A | N/A | N/A | N/A |\n| | | | | | | | | | |\n| KoSimCSE-BERT    | 74.08 | 74.92 | 73.98 | 
74.15 | 74.22 | 74.07 | 74.07 | 74.15 | 73.14 |\n| KoSimCSE-RoBERTa    | 75.27 | 75.93 | 75.00 | 75.28 | 75.01 | 75.17 | 74.83 | 75.95 | 75.01 |\n| | | | | | | | | | |\n| KoDiffCSE-RoBERTa    | 77.17 | 77.73 | 76.96 | 77.21 | 76.89 | 77.11 | 76.81 | 77.74 | 76.97 |\n\n- [Korean-SRoBERTa<sup>†</sup>](https://arxiv.org/abs/2004.03289)\n\n## Downstream tasks\n- KoSBERT: [Semantic Search](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSBERT#semantic-search), [Clustering](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSBERT#clustering)\n- KoSimCSE: [Semantic Search](https://github.com/BM-K/Sentence-Embedding-is-all-you-need/tree/main/KoSimCSE#semantic-search), [Clustering](https://github.com/BM-K/Sentence-Embedding-Is-All-You-Need/tree/main/KoSimCSE#clustering)\n\n## License\nThis work is licensed under a <a rel=\"license\" href=\"http://creativecommons.org/licenses/by-sa/4.0/\">Creative Commons Attribution-ShareAlike 4.0 International License</a>.\n\n<a rel=\"license\" href=\"http://creativecommons.org/licenses/by-sa/4.0/\"><img alt=\"Creative Commons License\" style=\"border-width:0\" src=\"https://i.creativecommons.org/l/by-sa/4.0/88x31.png\" /></a><br />\n\n## References\n\n```bibtex\n@misc{park2021klue,\n    title={KLUE: Korean Language Understanding Evaluation},\n    author={Sungjoon Park and Jihyung Moon and Sungdong Kim and Won Ik Cho and Jiyoon Han and Jangwon Park and Chisung Song and Junseong Kim and Yongsook Song and Taehwan Oh and Joohong Lee and Juhyun Oh and Sungwon Lyu and Younghoon Jeong and Inkwon Lee and Sangwoo Seo and Dongjun Lee and Hyunwoo Kim and Myeonghwa Lee and Seongbo Jang and Seungwon Do and Sunkyoung Kim and Kyungtae Lim and Jongwon Lee and Kyumin Park and Jamin Shin and Seonghyun Kim and Lucy Park and Alice Oh and Jung-Woo Ha and Kyunghyun Cho},\n    year={2021},\n    eprint={2105.09680},\n    archivePrefix={arXiv},\n    primaryClass={cs.CL}\n}\n\n@inproceedings{gao2021simcse,\n   
title={{SimCSE}: Simple Contrastive Learning of Sentence Embeddings},\n   author={Gao, Tianyu and Yao, Xingcheng and Chen, Danqi},\n   booktitle={Empirical Methods in Natural Language Processing (EMNLP)},\n   year={2021}\n}\n\n@article{ham2020kornli,\n  title={KorNLI and KorSTS: New Benchmark Datasets for Korean Natural Language Understanding},\n  author={Ham, Jiyeon and Choe, Yo Joong and Park, Kyubyong and Choi, Ilji and Soh, Hyungjoon},\n  journal={arXiv preprint arXiv:2004.03289},\n  year={2020}\n}\n\n@inproceedings{reimers-2019-sentence-bert,\n    title = \"Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks\",\n    author = \"Reimers, Nils and Gurevych, Iryna\",\n    booktitle = \"Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing\",\n    month = \"11\",\n    year = \"2019\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"http://arxiv.org/abs/1908.10084\",\n}\n\n@inproceedings{chuang2022diffcse,\n   title={{DiffCSE}: Difference-based Contrastive Learning for Sentence Embeddings},\n   author={Chuang, Yung-Sung and Dangovski, Rumen and Luo, Hongyin and Zhang, Yang and Chang, Shiyu and Soljacic, Marin and Li, Shang-Wen and Yih, Wen-tau and Kim, Yoon and Glass, James},\n   booktitle={Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL)},\n   year={2022}\n}\n```\n"
  },
  {
    "path": "get_model_checkpoint.sh",
    "content": "#!/bin/bash\npip install gdown\ngdown --folder https://drive.google.com/drive/folders/1orQxudCmdOLvRUFJdWEEK31l7IOVUnWs?usp=sharing -O Checkpoint\n"
  },
  {
    "path": "get_model_dataset.sh",
    "content": "#!/bin/bash\npip install gdown\ngdown --folder https://drive.google.com/drive/folders/140QpBbBPWXlqsbGZM1SpzhKDLMznq39B?usp=sharing -O Dataset\n"
  }
]