[
  {
    "path": ".dvc/.gitignore",
    "content": "/config.local\n/tmp\n/cache\n"
  },
  {
    "path": ".dvc/config",
    "content": "[core]\n    remote = model-store\n['remote \"storage\"']\n    url = gdrive://19JK5AFbqOBlrFVwDHjTrf9uvQFtS0954\n['remote \"model-store\"']\n    url = s3://models-dvc/trained_models/\n"
  },
  {
    "path": ".dvc/plots/confusion.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"facet\": {\n        \"field\": \"rev\",\n        \"type\": \"nominal\"\n    },\n    \"spec\": {\n        \"transform\": [\n            {\n                \"aggregate\": [\n                    {\n                        \"op\": \"count\",\n                        \"as\": \"xy_count\"\n                    }\n                ],\n                \"groupby\": [\n                    \"<DVC_METRIC_Y>\",\n                    \"<DVC_METRIC_X>\"\n                ]\n            },\n            {\n                \"impute\": \"xy_count\",\n                \"groupby\": [\n                    \"rev\",\n                    \"<DVC_METRIC_Y>\"\n                ],\n                \"key\": \"<DVC_METRIC_X>\",\n                \"value\": 0\n            },\n            {\n                \"impute\": \"xy_count\",\n                \"groupby\": [\n                    \"rev\",\n                    \"<DVC_METRIC_X>\"\n                ],\n                \"key\": \"<DVC_METRIC_Y>\",\n                \"value\": 0\n            },\n            {\n                \"joinaggregate\": [\n                    {\n                        \"op\": \"max\",\n                        \"field\": \"xy_count\",\n                        \"as\": \"max_count\"\n                    }\n                ],\n                \"groupby\": []\n            },\n            {\n                \"calculate\": \"datum.xy_count / datum.max_count\",\n                \"as\": \"percent_of_max\"\n            }\n        ],\n        \"encoding\": {\n            \"x\": {\n                \"field\": \"<DVC_METRIC_X>\",\n                \"type\": \"nominal\",\n                \"sort\": \"ascending\",\n                \"title\": \"<DVC_METRIC_X_LABEL>\"\n            },\n            \"y\": {\n                \"field\": 
\"<DVC_METRIC_Y>\",\n                \"type\": \"nominal\",\n                \"sort\": \"ascending\",\n                \"title\": \"<DVC_METRIC_Y_LABEL>\"\n            }\n        },\n        \"layer\": [\n            {\n                \"mark\": \"rect\",\n                \"width\": 300,\n                \"height\": 300,\n                \"encoding\": {\n                    \"color\": {\n                        \"field\": \"xy_count\",\n                        \"type\": \"quantitative\",\n                        \"title\": \"\",\n                        \"scale\": {\n                            \"domainMin\": 0,\n                            \"nice\": true\n                        }\n                    }\n                }\n            },\n            {\n                \"mark\": \"text\",\n                \"encoding\": {\n                    \"text\": {\n                        \"field\": \"xy_count\",\n                        \"type\": \"quantitative\"\n                    },\n                    \"color\": {\n                        \"condition\": {\n                            \"test\": \"datum.percent_of_max > 0.5\",\n                            \"value\": \"white\"\n                        },\n                        \"value\": \"black\"\n                    }\n                }\n            }\n        ]\n    }\n}\n"
  },
  {
    "path": ".dvc/plots/confusion_normalized.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"facet\": {\n        \"field\": \"rev\",\n        \"type\": \"nominal\"\n    },\n    \"spec\": {\n        \"transform\": [\n            {\n                \"aggregate\": [\n                    {\n                        \"op\": \"count\",\n                        \"as\": \"xy_count\"\n                    }\n                ],\n                \"groupby\": [\n                    \"<DVC_METRIC_Y>\",\n                    \"<DVC_METRIC_X>\"\n                ]\n            },\n            {\n                \"impute\": \"xy_count\",\n                \"groupby\": [\n                    \"rev\",\n                    \"<DVC_METRIC_Y>\"\n                ],\n                \"key\": \"<DVC_METRIC_X>\",\n                \"value\": 0\n            },\n            {\n                \"impute\": \"xy_count\",\n                \"groupby\": [\n                    \"rev\",\n                    \"<DVC_METRIC_X>\"\n                ],\n                \"key\": \"<DVC_METRIC_Y>\",\n                \"value\": 0\n            },\n            {\n                \"joinaggregate\": [\n                    {\n                        \"op\": \"sum\",\n                        \"field\": \"xy_count\",\n                        \"as\": \"sum_y\"\n                    }\n                ],\n                \"groupby\": [\n                    \"<DVC_METRIC_Y>\"\n                ]\n            },\n            {\n                \"calculate\": \"datum.xy_count / datum.sum_y\",\n                \"as\": \"percent_of_y\"\n            }\n        ],\n        \"encoding\": {\n            \"x\": {\n                \"field\": \"<DVC_METRIC_X>\",\n                \"type\": \"nominal\",\n                \"sort\": \"ascending\",\n                \"title\": \"<DVC_METRIC_X_LABEL>\"\n            },\n            \"y\": 
{\n                \"field\": \"<DVC_METRIC_Y>\",\n                \"type\": \"nominal\",\n                \"sort\": \"ascending\",\n                \"title\": \"<DVC_METRIC_Y_LABEL>\"\n            }\n        },\n        \"layer\": [\n            {\n                \"mark\": \"rect\",\n                \"width\": 300,\n                \"height\": 300,\n                \"encoding\": {\n                    \"color\": {\n                        \"field\": \"percent_of_y\",\n                        \"type\": \"quantitative\",\n                        \"title\": \"\",\n                        \"scale\": {\n                            \"domain\": [\n                                0,\n                                1\n                            ]\n                        }\n                    }\n                }\n            },\n            {\n                \"mark\": \"text\",\n                \"encoding\": {\n                    \"text\": {\n                        \"field\": \"percent_of_y\",\n                        \"type\": \"quantitative\",\n                        \"format\": \".2f\"\n                    },\n                    \"color\": {\n                        \"condition\": {\n                            \"test\": \"datum.percent_of_y > 0.5\",\n                            \"value\": \"white\"\n                        },\n                        \"value\": \"black\"\n                    }\n                }\n            }\n        ]\n    }\n}\n"
  },
  {
    "path": ".dvc/plots/default.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"width\": 300,\n    \"height\": 300,\n    \"mark\": {\n        \"type\": \"line\"\n    },\n    \"encoding\": {\n        \"x\": {\n            \"field\": \"<DVC_METRIC_X>\",\n            \"type\": \"quantitative\",\n            \"title\": \"<DVC_METRIC_X_LABEL>\"\n        },\n        \"y\": {\n            \"field\": \"<DVC_METRIC_Y>\",\n            \"type\": \"quantitative\",\n            \"title\": \"<DVC_METRIC_Y_LABEL>\",\n            \"scale\": {\n                \"zero\": false\n            }\n        },\n        \"color\": {\n            \"field\": \"rev\",\n            \"type\": \"nominal\"\n        }\n    }\n}\n"
  },
  {
    "path": ".dvc/plots/linear.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"width\": 300,\n    \"height\": 300,\n    \"layer\": [\n        {\n            \"encoding\": {\n                \"x\": {\n                    \"field\": \"<DVC_METRIC_X>\",\n                    \"type\": \"quantitative\",\n                    \"title\": \"<DVC_METRIC_X_LABEL>\"\n                },\n                \"y\": {\n                    \"field\": \"<DVC_METRIC_Y>\",\n                    \"type\": \"quantitative\",\n                    \"title\": \"<DVC_METRIC_Y_LABEL>\",\n                    \"scale\": {\n                        \"zero\": false\n                    }\n                },\n                \"color\": {\n                    \"field\": \"rev\",\n                    \"type\": \"nominal\"\n                }\n            },\n            \"layer\": [\n                {\n                    \"mark\": \"line\"\n                },\n                {\n                    \"selection\": {\n                        \"label\": {\n                            \"type\": \"single\",\n                            \"nearest\": true,\n                            \"on\": \"mouseover\",\n                            \"encodings\": [\n                                \"x\"\n                            ],\n                            \"empty\": \"none\",\n                            \"clear\": \"mouseout\"\n                        }\n                    },\n                    \"mark\": \"point\",\n                    \"encoding\": {\n                        \"opacity\": {\n                            \"condition\": {\n                                \"selection\": \"label\",\n                                \"value\": 1\n                            },\n                            \"value\": 0\n                        }\n                    }\n                }\n          
  ]\n        },\n        {\n            \"transform\": [\n                {\n                    \"filter\": {\n                        \"selection\": \"label\"\n                    }\n                }\n            ],\n            \"layer\": [\n                {\n                    \"mark\": {\n                        \"type\": \"rule\",\n                        \"color\": \"gray\"\n                    },\n                    \"encoding\": {\n                        \"x\": {\n                            \"field\": \"<DVC_METRIC_X>\",\n                            \"type\": \"quantitative\"\n                        }\n                    }\n                },\n                {\n                    \"encoding\": {\n                        \"text\": {\n                            \"type\": \"quantitative\",\n                            \"field\": \"<DVC_METRIC_Y>\"\n                        },\n                        \"x\": {\n                            \"field\": \"<DVC_METRIC_X>\",\n                            \"type\": \"quantitative\"\n                        },\n                        \"y\": {\n                            \"field\": \"<DVC_METRIC_Y>\",\n                            \"type\": \"quantitative\"\n                        }\n                    },\n                    \"layer\": [\n                        {\n                            \"mark\": {\n                                \"type\": \"text\",\n                                \"align\": \"left\",\n                                \"dx\": 5,\n                                \"dy\": -5\n                            },\n                            \"encoding\": {\n                                \"color\": {\n                                    \"type\": \"nominal\",\n                                    \"field\": \"rev\"\n                                }\n                            }\n                        }\n                    ]\n                }\n            ]\n        }\n    ]\n}\n"
  },
  {
    "path": ".dvc/plots/scatter.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"width\": 300,\n    \"height\": 300,\n    \"layer\": [\n        {\n            \"encoding\": {\n                \"x\": {\n                    \"field\": \"<DVC_METRIC_X>\",\n                    \"type\": \"quantitative\",\n                    \"title\": \"<DVC_METRIC_X_LABEL>\"\n                },\n                \"y\": {\n                    \"field\": \"<DVC_METRIC_Y>\",\n                    \"type\": \"quantitative\",\n                    \"title\": \"<DVC_METRIC_Y_LABEL>\",\n                    \"scale\": {\n                        \"zero\": false\n                    }\n                },\n                \"color\": {\n                    \"field\": \"rev\",\n                    \"type\": \"nominal\"\n                }\n            },\n            \"layer\": [\n                {\n                    \"mark\": \"point\"\n                },\n                {\n                    \"selection\": {\n                        \"label\": {\n                            \"type\": \"single\",\n                            \"nearest\": true,\n                            \"on\": \"mouseover\",\n                            \"encodings\": [\n                                \"x\"\n                            ],\n                            \"empty\": \"none\",\n                            \"clear\": \"mouseout\"\n                        }\n                    },\n                    \"mark\": \"point\",\n                    \"encoding\": {\n                        \"opacity\": {\n                            \"condition\": {\n                                \"selection\": \"label\",\n                                \"value\": 1\n                            },\n                            \"value\": 0\n                        }\n                    }\n                }\n         
   ]\n        },\n        {\n            \"transform\": [\n                {\n                    \"filter\": {\n                        \"selection\": \"label\"\n                    }\n                }\n            ],\n            \"layer\": [\n                {\n                    \"encoding\": {\n                        \"text\": {\n                            \"type\": \"quantitative\",\n                            \"field\": \"<DVC_METRIC_Y>\"\n                        },\n                        \"x\": {\n                            \"field\": \"<DVC_METRIC_X>\",\n                            \"type\": \"quantitative\"\n                        },\n                        \"y\": {\n                            \"field\": \"<DVC_METRIC_Y>\",\n                            \"type\": \"quantitative\"\n                        }\n                    },\n                    \"layer\": [\n                        {\n                            \"mark\": {\n                                \"type\": \"text\",\n                                \"align\": \"left\",\n                                \"dx\": 5,\n                                \"dy\": -5\n                            },\n                            \"encoding\": {\n                                \"color\": {\n                                    \"type\": \"nominal\",\n                                    \"field\": \"rev\"\n                                }\n                            }\n                        }\n                    ]\n                }\n            ]\n        }\n    ]\n}\n"
  },
  {
    "path": ".dvc/plots/smooth.json",
    "content": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n    },\n    \"title\": \"<DVC_METRIC_TITLE>\",\n    \"mark\": {\n        \"type\": \"line\"\n    },\n    \"encoding\": {\n        \"x\": {\n            \"field\": \"<DVC_METRIC_X>\",\n            \"type\": \"quantitative\",\n            \"title\": \"<DVC_METRIC_X_LABEL>\"\n        },\n        \"y\": {\n            \"field\": \"<DVC_METRIC_Y>\",\n            \"type\": \"quantitative\",\n            \"title\": \"<DVC_METRIC_Y_LABEL>\",\n            \"scale\": {\n                \"zero\": false\n            }\n        },\n        \"color\": {\n            \"field\": \"rev\",\n            \"type\": \"nominal\"\n        }\n    },\n    \"transform\": [\n        {\n            \"loess\": \"<DVC_METRIC_Y>\",\n            \"on\": \"<DVC_METRIC_X>\",\n            \"groupby\": [\n                \"rev\"\n            ],\n            \"bandwidth\": 0.3\n        }\n    ]\n}\n"
  },
  {
    "path": ".dvcignore",
    "content": "# Add patterns of files dvc should ignore, which could improve\n# the performance. Learn more at\n# https://dvc.org/doc/user-guide/dvcignore\n"
  },
  {
    "path": ".github/workflows/basic.yaml",
    "content": "name: GitHub Actions Basic Flow\non: [push]\njobs:\n  Basic-workflow:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Basic Information\n        run: |\n          echo \"🎬 The job was automatically triggered by a ${{ github.event_name }} event.\"\n          echo \"💻 This job is now running on a ${{ runner.os }} server hosted by GitHub!\"\n          echo \"🎋 Workflow is running on the branch ${{ github.ref }}\"\n      - name: Checking out the repository\n        uses: actions/checkout@v2\n      - name: Information after checking out\n        run: |\n          echo \"💡 The ${{ github.repository }} repository has been cloned to the runner.\"\n          echo \"🖥️ The workflow is now ready to test your code on the runner.\"\n      - name: List files in the repository\n        run: |\n          ls ${{ github.workspace }}\n      - run: echo \"🍏 This job's status is ${{ job.status }}.\""
  },
  {
    "path": ".github/workflows/build_docker_image.yaml",
    "content": "name: Create Docker Container\n\non: [push]\n\njobs:\n  mlops-container:\n    runs-on: ubuntu-latest\n    defaults:\n      run:\n        working-directory: ./week_9_monitoring\n    steps:\n    - name: Checkout\n      uses: actions/checkout@v2\n      with:\n        ref: ${{ github.ref }}\n    - name: Configure AWS Credentials\n      uses: aws-actions/configure-aws-credentials@v1\n      with:\n        aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}\n        aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}\n        aws-region: us-west-2\n    - name: Build container\n      run: |\n        docker build --build-arg AWS_ACCOUNT_ID=${{ secrets.AWS_ACCOUNT_ID }} \\\n                     --build-arg AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }} \\\n                     --build-arg AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }} \\\n                     --tag mlops-basics .\n    - name: Push2ECR\n      id: ecr\n      uses: jwalton/gh-ecr-push@v1\n      with:\n        access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}\n        secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}\n        region: us-west-2\n        image: mlops-basics:latest\n    \n    - name: Update lambda with image\n      run: aws lambda update-function-code --function-name  MLOps-Basics --image-uri 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n.vscode/\n*/logs/*\n*/models/*\n*/wandb/*\n*/outputs/*\n*/multirun/*\n\n.DS_Store\n*/.DS_Store"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 raviraja\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# MLOps-Basics\n\n > There is nothing magic about magic. The magician merely understands something simple which doesn’t appear to be simple or natural to the untrained audience. Once you learn how to hold a card while making your hand look empty, you only need practice before you, too, can “do magic.” – Jeffrey Friedl in the book Mastering Regular Expressions\n\n**Note: Please raise an issue for any suggestions, corrections, and feedback.**\n\nThe goal of the series is to understand the basics of MLOps like model building, monitoring, configurations, testing, packaging, deployment, CI/CD, etc.\n\n![pl](images/summary.png)\n\n## Week 0: Project Setup\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-project-setup-part1)\n\nThe project I have implemented is a simple classification problem. The scope of this week is to understand the following topics:\n\n- `How to get the data?`\n- `How to process the data?`\n- `How to define dataloaders?`\n- `How to declare the model?`\n- `How to train the model?`\n- `How to do the inference?`\n\n![pl](images/pl.jpeg)\n\nFollowing tech stack is used:\n\n- [Huggingface Datasets](https://github.com/huggingface/datasets)\n- [Huggingface Transformers](https://github.com/huggingface/transformers)\n- [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/)\n\n## Week 1: Model monitoring - Weights and Biases\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-wandb-integration)\n\nTracking all the experiments like tweaking hyper-parameters, trying different models to test their performance and seeing the connection between the model and the input data will help in developing a better model.\n\nThe scope of this week is to 
understand the following topics:\n\n- `How to configure basic logging with W&B?`\n- `How to compute metrics and log them in W&B?`\n- `How to add plots in W&B?`\n- `How to add data samples to W&B?`\n\n![wandb](images/wandb.png)\n\nFollowing tech stack is used:\n\n- [Weights and Biases](https://wandb.ai/site)\n- [torchmetrics](https://torchmetrics.readthedocs.io/)\n\nReferences:\n\n- [Tutorial on Pytorch Lightning + Weights & Biases](https://www.youtube.com/watch?v=hUXQm46TAKc)\n\n- [WandB Documentation](https://docs.wandb.ai/)\n\n## Week 2: Configurations - Hydra\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-hydra-config)\n\nConfiguration management is necessary for managing complex software systems. Lack of configuration management can cause serious problems with reliability, uptime, and the ability to scale a system.\n\nThe scope of this week is to understand the following topics:\n\n- `Basics of Hydra`\n- `Overriding configurations`\n- `Splitting configuration across multiple files`\n- `Variable Interpolation`\n- `How to run model with different parameter combinations?`\n\n![hydra](images/hydra.png)\n\nFollowing tech stack is used:\n\n- [Hydra](https://hydra.cc/)\n\nReferences\n\n- [Hydra Documentation](https://hydra.cc/docs/intro)\n\n- [Simone Tutorial on Hydra](https://www.sscardapane.it/tutorials/hydra-tutorial/#executing-multiple-runs)\n\n\n## Week 3: Data Version Control - DVC\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-dvc)\n\nClassical code version control systems are not designed to handle large files, which make cloning and storing the history impractical. 
Such files are very common in Machine Learning.\n\nThe scope of this week is to understand the following topics:\n\n- `Basics of DVC`\n- `Initialising DVC`\n- `Configuring Remote Storage`\n- `Saving Model to the Remote Storage`\n- `Versioning the models`\n\n![dvc](images/dvc.png)\n\nFollowing tech stack is used:\n\n- [DVC](https://dvc.org/)\n\nReferences\n\n- [DVC Documentation](https://dvc.org/doc)\n\n- [DVC Tutorial on Versioning data](https://www.youtube.com/watch?v=kLKBcPonMYw)\n\n## Week 4: Model Packaging - ONNX\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-onnx)\n\nWhy do we need model packaging? Models can be built using any machine learning framework available out there (sklearn, tensorflow, pytorch, etc.). We might want to deploy models in different environments (mobile, web, Raspberry Pi) or want to run them in a different framework (trained in pytorch, inference in tensorflow).\nA common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers will help a lot.\n\nThis is achieved by the community project `ONNX`.\n\nThe scope of this week is to understand the following topics:\n\n- `What is ONNX?`\n\n- `How to convert a trained model to ONNX format?`\n\n- `What is ONNX Runtime?`\n\n- `How to run ONNX converted model in ONNX Runtime?`\n\n- `Comparisons`\n\n![ONNX](images/onnx.jpeg)\n\nFollowing tech stack is used:\n\n- [ONNX](https://onnx.ai/)\n- [ONNXRuntime](https://www.onnxruntime.ai/)\n\nReferences\n\n- [Abhishek Thakur tutorial on onnx model conversion](https://www.youtube.com/watch?v=7nutT3Aacyw)\n- [Pytorch Lightning documentation on onnx conversion](https://pytorch-lightning.readthedocs.io/en/stable/common/production_inference.html)\n- [Huggingface Blog on 
ONNXRuntime](https://medium.com/microsoftazure/accelerate-your-nlp-pipelines-using-hugging-face-transformers-and-onnx-runtime-2443578f4333)\n- [Piotr Blog on onnx conversion](https://tugot17.github.io/data-science-blog/onnx/tutorial/2020/09/21/Exporting-lightning-model-to-onnx.html)\n\n\n## Week 5: Model Packaging - Docker\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-docker)\n\nWhy do we need packaging? We might have to share our application with others, and when they try to run the application, most of the time it doesn’t run due to dependency or OS-related issues, and for that we say (famous quote across engineers) `It works on my laptop/system`.\n\nSo for others to run the applications they have to set up the same environment as it was run on the host side which means a lot of manual configuration and installation of components.\n\nThe solution to these limitations is a technology called Containers.\n\nBy containerizing/packaging the application, we can run the application on any cloud platform and take advantage of managed services, autoscaling, reliability, and much more.\n\nThe most prominent tool for packaging applications is Docker 🛳\n\nThe scope of this week is to understand the following topics:\n\n- `FastAPI wrapper`\n- `Basics of Docker`\n- `Building Docker Container`\n- `Docker Compose`\n\n![Docker](images/docker_flow.png)\n\nReferences\n\n- [Analytics vidhya blog](https://www.analyticsvidhya.com/blog/2021/06/a-hands-on-guide-to-containerized-your-machine-learning-workflow-with-docker/)\n\n\n## Week 6: CI/CD - GitHub Actions\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-github-actions)\n\nCI/CD is a coding philosophy and 
set of practices with which you can continuously build, test, and deploy iterative code changes.\n\nThis iterative process helps reduce the chance that you develop new code based on a buggy or failed previous version. With this method, you strive to have less human intervention or even no intervention at all, from the development of new code until its deployment.\n\nIn this week, I will be going through the following topics:\n\n- Basics of GitHub Actions\n- First GitHub Action\n- Creating Google Service Account\n- Giving access to Service account\n- Configuring DVC to use Google Service account\n- Configuring Github Action\n\n![Docker](images/basic_flow.png)\n\nReferences\n\n- [Configuring service account](https://dvc.org/doc/user-guide/setup-google-drive-remote)\n\n- [Github actions](https://docs.github.com/en/actions/quickstart)\n\n\n## Week 7: Container Registry - AWS ECR\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-container-registry)\n\nA container registry is a place to store container images. A container image is a file composed of multiple layers which can execute applications in a single instance. Hosting all the images in one location allows users to commit, identify and pull images when needed.\n\nAmazon Simple Storage Service (S3) is a storage service for the internet. 
It is designed for large-capacity, low-cost storage provision across multiple geographical regions.\n\nIn this week, I will be going through the following topics:\n\n- `Basics of S3`\n\n- `Programmatic access to S3`\n\n- `Configuring AWS S3 as remote storage in DVC`\n\n- `Basics of ECR`\n\n- `Configuring GitHub Actions to use S3, ECR`\n\n![Docker](images/ecr_flow.png)\n\n\n## Week 8: Serverless Deployment - AWS Lambda\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-serverless)\n\nA serverless architecture is a way to build and run applications and services without having to manage infrastructure. The application still runs on servers, but all the server management is done by a third-party service (AWS). We no longer have to provision, scale, and maintain servers to run the applications. By using a serverless architecture, developers can focus on their core product instead of worrying about managing and operating servers or runtimes, either in the cloud or on-premises.\n\nIn this week, I will be going through the following topics:\n\n- `Basics of Serverless`\n\n- `Basics of AWS Lambda`\n\n- `Triggering Lambda with API Gateway`\n\n- `Deploying Container using Lambda`\n\n- `Automating deployment to Lambda using Github Actions`\n\n![Docker](images/lambda_flow.png)\n\n\n## Week 9: Prediction Monitoring - Kibana\n\n<img src=\"https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange\"/>\n\nRefer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-monitoring)\n\n\nMonitoring systems can help give us confidence that our systems are running smoothly and, in the event of a system failure, can quickly provide appropriate context when diagnosing the root cause.\n\nThings we want to monitor during training and during inference are different. 
During training we are concered about whether the loss is decreasing or not, whether the model is overfitting, etc.\n\nBut, during inference, We like to have confidence that our model is making correct predictions.\n\nThere are many reasons why a model can fail to make useful predictions:\n\n- The underlying data distribution has shifted over time and the model has gone stale. i.e inference data characteristics is different from the data characteristics used to train the model.\n\n- The inference data stream contains edge cases (not seen during model training). In this scenarios model might perform poorly or can lead to errors.\n\n- The model was misconfigured in its production deployment. (Configuration issues are common)\n\nIn all of these scenarios, the model could still make a `successful` prediction from a service perspective, but the predictions will likely not be useful. Monitoring machine learning models can help us detect such scenarios and intervene (e.g. trigger a model retraining/deployment pipeline).\n\nIn this week, I will be going through the following topics:\n\n- `Basics of Cloudwatch Logs`\n\n- `Creating Elastic Search Cluster`\n\n- `Configuring Cloudwatch Logs with Elastic Search`\n\n- `Creating Index Patterns in Kibana`\n\n- `Creating Kibana Visualisations`\n\n- `Creating Kibana Dashboard`\n\n![Docker](images/kibana_flow.png)\n"
  },
  {
    "path": "week_0_project_setup/README.md",
    "content": "\n**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following commands:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Inference\n\nAfter training, update the model checkpoint path in the code and run\n\n```\npython inference.py\n```\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks. \n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure to use the virtualenv, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```\n\n\n"
  },
  {
    "path": "week_0_project_setup/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", batch_size=32):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=512,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n    data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_0_project_setup/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_0_project_setup/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=0)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/epoch=0-step=267.ckpt\")\n    print(predictor.predict(sentence))\n"
  },
  {
    "path": "week_0_project_setup/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom transformers import AutoModel\nfrom sklearn.metrics import accuracy_score\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=1e-2):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModel.from_pretrained(model_name)\n        self.W = nn.Linear(self.bert.config.hidden_size, 2)\n        self.num_classes = 2\n\n    def forward(self, input_ids, attention_mask):\n        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n\n        h_cls = outputs.last_hidden_state[:, 0]\n        logits = self.W(h_cls)\n        return logits\n\n    def training_step(self, batch, batch_idx):\n        logits = self.forward(batch[\"input_ids\"], batch[\"attention_mask\"])\n        loss = F.cross_entropy(logits, batch[\"label\"])\n        self.log(\"train_loss\", loss, prog_bar=True)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        logits = self.forward(batch[\"input_ids\"], batch[\"attention_mask\"])\n        loss = F.cross_entropy(logits, batch[\"label\"])\n        _, preds = torch.max(logits, dim=1)\n        val_acc = accuracy_score(preds.cpu(), batch[\"label\"].cpu())\n        val_acc = torch.tensor(val_acc)\n        self.log(\"val_loss\", loss, prog_bar=True)\n        self.log(\"val_acc\", val_acc, prog_bar=True)\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_0_project_setup/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2"
  },
  {
    "path": "week_0_project_setup/train.py",
    "content": "import torch\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\n\nfrom data import DataModule\nfrom model import ColaModel\n\n\ndef main():\n    cola_data = DataModule()\n    cola_model = ColaModel()\n\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=\"./models\", monitor=\"val_loss\", mode=\"min\"\n    )\n    early_stopping_callback = EarlyStopping(\n        monitor=\"val_loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    trainer = pl.Trainer(\n        default_root_dir=\"logs\",\n        gpus=(1 if torch.cuda.is_available() else 0),\n        max_epochs=5,\n        fast_dev_run=False,\n        logger=pl.loggers.TensorBoardLogger(\"logs/\", name=\"cola\", version=1),\n        callbacks=[checkpoint_callback, early_stopping_callback],\n    )\n    trainer.fit(cola_model, cola_data)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_1_wandb_logging/README.md",
    "content": "\n**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following commands:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce training is completed, you will see something like this at the end of the logs:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb: \nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard, which contains all the plots.\n\n### Inference\n\nAfter training, update the model checkpoint path in the code and run:\n\n```\npython inference.py\n```\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure the virtualenv is used, run the following commands before running `jupyter lab`:\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_1_wandb_logging/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", batch_size=64):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=128,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n    data_model.prepare_data()\n    data_model.setup()\n    
print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_1_wandb_logging/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_1_wandb_logging/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=0)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/epoch=0-step=267.ckpt\")\n    print(predictor.predict(sentence))\n"
  },
  {
    "path": "week_1_wandb_logging/model.py",
    "content": "import torch\nimport wandb\nimport numpy as np\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nimport torchmetrics\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n 
       return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_1_wandb_logging/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_1_wandb_logging/train.py",
    "content": "import torch\nimport wandb\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\ndef main():\n    cola_data = DataModule()\n    cola_model = ColaModel()\n\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=\"./models\",\n        filename=\"best-checkpoint.ckpt\",\n        monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=1,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=10,\n        deterministic=True,\n        # 
limit_train_batches=0.25,\n        # limit_val_batches=0.25\n    )\n    trainer.fit(cola_model, cola_data)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_2_hydra_config/README.md",
    "content": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed in the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb: \nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Inference\n\nAfter training, update the model checkpoint path in the code and run\n\n```\npython inference.py\n```\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks. \n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure to use the virutalenv, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_2_hydra_config/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_2_hydra_config/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_2_hydra_config/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_2_hydra_config/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_2_hydra_config/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_2_hydra_config/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_2_hydra_config/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=0)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n"
  },
  {
    "path": "week_2_hydra_config/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matrix plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_2_hydra_config/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog"
  },
  {
    "path": "week_2_hydra_config/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=\"./models\",\n        filename=\"best-checkpoint\",\n        monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    
early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        limit_train_batches=cfg.training.limit_train_batches,\n        limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_3_dvc/README.md",
    "content": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed in the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Inference\n\nAfter training, update the model checkpoint path in the code and run\n\n```\npython inference.py\n```\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure to use the virutalenv, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_3_dvc/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_3_dvc/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_3_dvc/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_3_dvc/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_3_dvc/data.py",
    "content": "import torch\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n    
data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_3_dvc/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_3_dvc/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_3_dvc/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    \"\"\"Run single-sentence CoLA acceptability predictions from a trained checkpoint.\"\"\"\n\n    def __init__(self, model_path):\n        \"\"\"Load the Lightning checkpoint at ``model_path`` in frozen eval mode.\"\"\"\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        # The model returns logits shaped (batch, num_labels), so the softmax\n        # must normalise over the label axis (dim=1). Using dim=0 (the size-1\n        # batch axis) would turn every score into 1.0.\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.labels = [\"unacceptable\", \"acceptable\"]\n        # Backward-compatible alias for the original, misspelled attribute name.\n        self.lables = self.labels\n\n    def predict(self, text):\n        \"\"\"Return one {\"label\": ..., \"score\": ...} dict per class for ``text``.\"\"\"\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        # Wrapping each list in [...] adds a batch dimension of size 1.\n        outputs = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        # ColaModel.forward returns the transformers output object; use its logits.\n        scores = self.softmax(outputs.logits).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.labels):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n"
  },
  {
    "path": "week_3_dvc/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_3_dvc/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog"
  },
  {
    "path": "week_3_dvc/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_4_onnx/README.md",
    "content": "\n**Note: The purpose of this project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8.\n\nCreate a virtual env with the following commands:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, train the model by running:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce training completes, you will see something like this at the end of the logs:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the W&B dashboard, which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert it to ONNX with the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard PyTorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, running `jupyter lab` may or may not pick up the virtualenv.\n\nTo make sure Jupyter uses the virtualenv, run the following commands before running `jupyter lab`:\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_4_onnx/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_4_onnx/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_4_onnx/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_4_onnx/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_4_onnx/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_4_onnx/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_4_onnx/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_4_onnx/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_4_onnx/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_4_onnx/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    def __init__(self, model_path):\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        scores = softmax(ort_outs[0])[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_4_onnx/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_4_onnx/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog"
  },
  {
    "path": "week_4_onnx/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_4_onnx/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        start = time.time()\n        result = f(*args, **kwargs)\n        end = time.time()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  },
  {
    "path": "week_5_docker/Dockerfile",
    "content": "FROM huggingface/transformers-pytorch-cpu:latest\nCOPY ./ /app\nWORKDIR /app\nRUN pip install -r requirements_prod.txt\nENV LC_ALL=C.UTF-8\nENV LANG=C.UTF-8\nEXPOSE 8000\nCMD [\"uvicorn\", \"app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"]\n"
  },
  {
    "path": "week_5_docker/README.md",
    "content": "\n**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following commands:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce training is completed, you will see something like the following at the end of the logs:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert the model using the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard PyTorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n### Docker\n\nInstall Docker using the [instructions here](https://docs.docker.com/engine/install/)\n\nBuild the image using the command\n\n```shell\ndocker build -t inference:latest .\n```\n\nThen run the container using the command\n\n```shell\ndocker run -p 8000:8000 --name inference_container inference:latest\n```\n\n(or)\n\nBuild and run the container using the command\n\n```shell\ndocker-compose up\n```\n\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure to use 
the virtualenv, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_5_docker/app.py",
    "content": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredictor = ColaONNXPredictor(\"./models/model.onnx\")\n\n@app.get(\"/\")\nasync def home_page():\n    return \"<h2>Sample prediction API</h2>\"\n\n\n@app.get(\"/predict\")\nasync def get_prediction(text: str):\n    result =  predictor.predict(text)\n    return result"
  },
  {
    "path": "week_5_docker/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_5_docker/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_5_docker/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_5_docker/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_5_docker/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_5_docker/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_5_docker/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:\n            - \"8000:8000\""
  },
  {
    "path": "week_5_docker/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_5_docker/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_5_docker/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_5_docker/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    def __init__(self, model_path):\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        scores = softmax(ort_outs[0])[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": float(score)})\n        print(predictions)\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_5_docker/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_5_docker/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_5_docker/requirements_inference.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_5_docker/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_5_docker/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        start = time.time()\n        result = f(*args, **kwargs)\n        end = time.time()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  },
  {
    "path": "week_6_github_actions/Dockerfile",
    "content": "FROM huggingface/transformers-pytorch-cpu:latest\n\nCOPY ./ /app\nWORKDIR /app\n\n# install requirements\nRUN pip install \"dvc[gdrive]\"\nRUN pip install -r requirements_inference.txt\n\n# initialise dvc\nRUN dvc init --no-scm\n# configuring remote server in dvc\nRUN dvc remote add -d storage gdrive://19JK5AFbqOBlrFVwDHjTrf9uvQFtS0954\nRUN dvc remote modify storage gdrive_use_service_account true\nRUN dvc remote modify storage gdrive_service_account_json_file_path creds.json\n\nRUN cat .dvc/config\n# pulling the trained model\nRUN dvc pull dvcfiles/trained_model.dvc\n\nENV LC_ALL=C.UTF-8\nENV LANG=C.UTF-8\n\n# running the application\nEXPOSE 8000\nCMD [\"uvicorn\", \"app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"]\n"
  },
  {
    "path": "week_6_github_actions/README.md",
    "content": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed in the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert the model using the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard pytorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n### Google Service account\n\nCreate service account using the steps mentioned here: [Create service account](https://www.ravirajag.dev/blog/mlops-github-actions)\n\n### Configuring dvc\n\n```\ndvc init\ndvc remote add -d storage gdrive://19JK5AFbqOBlrFVwDHjTrf9uvQFtS0954\ndvc remote modify storage gdrive_use_service_account true\ndvc remote modify storage gdrive_service_account_json_file_path creds.json\n```\n\n`creds.json` is the file created during service account creation\n\n\n### Docker\n\nInstall the docker using the [instructions here](https://docs.docker.com/engine/install/)\n\nBuild the image using the command\n\n```shell\ndocker build -t 
inference:latest .\n```\n\nThen run the container using the command\n\n```shell\ndocker run -p 8000:8000 --name inference_container inference:latest\n```\n\n(or)\n\nBuild and run the container using the command\n\n```shell\ndocker-compose up\n```\n\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure to use the virutalenv, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_6_github_actions/app.py",
    "content": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredictor = ColaONNXPredictor(\"./models/model.onnx\")\n\n@app.get(\"/\")\nasync def home_page():\n    return \"<h2>Sample prediction API</h2>\"\n\n\n@app.get(\"/predict\")\nasync def get_prediction(text: str):\n    result =  predictor.predict(text)\n    return result"
  },
  {
    "path": "week_6_github_actions/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_6_github_actions/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_6_github_actions/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_6_github_actions/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_6_github_actions/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_6_github_actions/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_6_github_actions/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:\n            - \"8000:8000\""
  },
  {
    "path": "week_6_github_actions/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: d82b8390fa2f09b121de4abfa094a7a9\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_6_github_actions/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_6_github_actions/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_6_github_actions/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    def __init__(self, model_path):\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        scores = softmax(ort_outs[0])[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": float(score)})\n        print(predictions)\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_6_github_actions/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_6_github_actions/parse_json.py",
    "content": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print(data)\ndata = eval(data)\nprint(data)\n\nwith open('test.json', 'w') as f:\n\tjson.dump(data, f)\n"
  },
  {
    "path": "week_6_github_actions/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_6_github_actions/requirements_inference.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nuvicorn\ndvc"
  },
  {
    "path": "week_6_github_actions/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_6_github_actions/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        start = time.time()\n        result = f(*args, **kwargs)\n        end = time.time()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  },
  {
    "path": "week_7_ecr/Dockerfile",
    "content": "FROM huggingface/transformers-pytorch-cpu:latest\n\nCOPY ./ /app\nWORKDIR /app\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS_KEY\n\n\n#this envs are experimental\nENV AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \\\n    AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY\n\n\n# install requirements\nRUN pip install \"dvc[s3]\"\nRUN pip install -r requirements_inference.txt\n\n# initialise dvc\nRUN dvc init --no-scm\n# configuring remote server in dvc\nRUN dvc remote add -d model-store s3://models-dvc/trained_models/\n\nRUN cat .dvc/config\n# pulling the trained model\nRUN dvc pull dvcfiles/trained_model.dvc\n\nENV LC_ALL=C.UTF-8\nENV LANG=C.UTF-8\n\n# running the application\nEXPOSE 8000\nCMD [\"uvicorn\", \"app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"]\n"
  },
  {
    "path": "week_7_ecr/README.md",
    "content": "\n**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed, at the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert the model using the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard PyTorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n## S3 & ECR\n\nFollow the instructions mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating an S3 bucket and an ECR repository. 
\n\n### Configuring dvc\n\n```\ndvc init  # this has to be done at the root folder\ndvc remote add -d model-store s3://models-dvc/trained_models/\n```\n\n### AWS credentials\n\nCreate the credentials as mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry)\n\n**Do not share the secrets with others**\n\nSet the access key ID and secret access key as environment variables:\n\n```\nexport AWS_ACCESS_KEY_ID=<ACCESS KEY ID>\nexport AWS_SECRET_ACCESS_KEY=<ACCESS SECRET>\n```\n\n### Trained model in DVC\n\nAdd the trained model (ONNX) to DVC using the following command:\n\n```shell\ncd dvcfiles\ndvc add ../models/model.onnx --file trained_model.dvc\n```\n\nPush the model to remote storage:\n\n```shell\ndvc push trained_model.dvc\n```\n\n### Docker\n\nInstall Docker using the [instructions here](https://docs.docker.com/engine/install/)\n\nBuild the image using the command\n\n```shell\ndocker build -t mlops-basics:latest .\n```\n\nThen run the container using the command\n\n```shell\ndocker run -p 8000:8000 --name inference_container mlops-basics:latest\n```\n\n(or)\n\nBuild and run the container using the command\n\n```shell\ndocker-compose up\n```\n\n### Pushing the image to ECR\n\nFollow the instructions mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating an ECR repository.\n\n- Authenticating docker client to ECR\n\n```\naws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 246113150184.dkr.ecr.us-west-2.amazonaws.com\n```\n\n- Tagging the image\n\n```\ndocker tag mlops-basics:latest 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\n- Pushing the image\n\n```\ndocker push 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\nRefer to the `.github/workflows/build_docker_image.yaml` file for automatically building the docker image with the trained model and pushing it to ECR.\n\n\n### Running notebooks\n\nI am using [Jupyter 
lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure Jupyter uses the virtualenv, run the following commands before running `jupyter lab`:\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_7_ecr/app.py",
    "content": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredictor = ColaONNXPredictor(\"./models/model.onnx\")\n\n@app.get(\"/\")\nasync def home_page():\n    return \"<h2>Sample prediction API</h2>\"\n\n\n@app.get(\"/predict\")\nasync def get_prediction(text: str):\n    result =  predictor.predict(text)\n    return result"
  },
  {
    "path": "week_7_ecr/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_7_ecr/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_7_ecr/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_7_ecr/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_7_ecr/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_7_ecr/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_7_ecr/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:\n            - \"8000:8000\""
  },
  {
    "path": "week_7_ecr/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_7_ecr/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_7_ecr/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    def __init__(self, model_path):\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        logits = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        scores = self.softmax(logits[0]).tolist()[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": score})\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_7_ecr/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    def __init__(self, model_path):\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        self.lables = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        scores = softmax(ort_outs[0])[0]\n        predictions = []\n        for score, label in zip(scores, self.lables):\n            predictions.append({\"label\": label, \"score\": float(score)})\n        print(predictions)\n        return predictions\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_7_ecr/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_7_ecr/parse_json.py",
    "content": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print(data)\ndata = eval(data)\nprint(data)\n\nwith open('test.json', 'w') as f:\n\tjson.dump(data, f)\n"
  },
  {
    "path": "week_7_ecr/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_7_ecr/requirements_inference.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nuvicorn\ndvc"
  },
  {
    "path": "week_7_ecr/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_7_ecr/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        start = time.time()\n        result = f(*args, **kwargs)\n        end = time.time()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  },
  {
    "path": "week_8_serverless/Dockerfile",
    "content": "FROM amazon/aws-lambda-python\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS_KEY\nARG MODEL_DIR=./models\nRUN mkdir $MODEL_DIR\n\nENV TRANSFORMERS_CACHE=$MODEL_DIR \\\n    TRANSFORMERS_VERBOSITY=error\n\nENV AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \\\n    AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY\n\nRUN yum install git -y && yum -y install gcc-c++\nCOPY requirements_inference.txt requirements_inference.txt\nRUN pip install -r requirements_inference.txt --no-cache-dir\nCOPY ./ ./\nENV PYTHONPATH \"${PYTHONPATH}:./\"\nENV LC_ALL=C.UTF-8\nENV LANG=C.UTF-8\nRUN pip install \"dvc[s3]\"\n# configuring remote server in dvc\nRUN dvc init --no-scm\nRUN dvc remote add -d model-store s3://models-dvc/trained_models/\n\n# pulling the trained model\nRUN dvc pull dvcfiles/trained_model.dvc\nRUN ls\nRUN python lambda_handler.py\nRUN chmod -R 0755 $MODEL_DIR\nCMD [ \"lambda_handler.lambda_handler\"]"
  },
  {
    "path": "week_8_serverless/README.md",
    "content": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed in the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert the model using the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard pytorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n## S3 & ECR\n\nFollow the instructions mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating S3 bucket and ECR repository. 
\n\n### Configuring dvc\n\n```\ndvc init (this has to be done at root folder)\ndvc remote add -d model-store s3://models-dvc/trained_models/\n```\n\n### AWS credentials\n\nCreate the credentials as mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry)\n\n**Do not share the secrets with others**\n\nSet the access key ID and secret access key as environment variables.\n\n```\nexport AWS_ACCESS_KEY_ID=<ACCESS KEY ID>\nexport AWS_SECRET_ACCESS_KEY=<ACCESS SECRET>\n```\n\n### Trained model in DVC\n\nAdd the trained model (ONNX) to DVC using the following command:\n\n```shell\ncd dvcfiles\ndvc add ../models/model.onnx --file trained_model.dvc\n```\n\nPush the model to remote storage:\n\n```shell\ndvc push trained_model.dvc\n```\n\n### Docker\n\nInstall Docker using the [instructions here](https://docs.docker.com/engine/install/)\n\nBuild the image using the command:\n\n```shell\ndocker build -t mlops-basics:latest .\n```\n\n**The default command in the Dockerfile is modified to support Lambda. 
If you want to run without Lambda, use last week's Dockerfile.**\n\nThen run the container using the command:\n\n```shell\ndocker run -p 8000:8000 --name inference_container mlops-basics:latest\n```\n\n(or)\n\nBuild and run the container using the command:\n\n```shell\ndocker-compose up\n```\n\n### Pushing the image to ECR\n\nFollow the instructions mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating an ECR repository.\n\n- Authenticating the Docker client to ECR\n\n```\naws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 246113150184.dkr.ecr.us-west-2.amazonaws.com\n```\n\n- Tagging the image\n\n```\ndocker tag mlops-basics:latest 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\n- Pushing the image\n\n```\ndocker push 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\nRefer to the `.github/workflows/build_docker_image.yaml` file for automatically creating the Docker image with the trained model and pushing it to ECR.\n\n### Serverless - Lambda\n\nRefer to the [Blog Post here](https://www.ravirajag.dev/blog/mlops-serverless) for detailed instructions on configuring Lambda with the Docker image and invoking it using an API.\n\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure Jupyter uses the virtualenv, run the following commands before running `jupyter lab`:\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_8_serverless/app.py",
    "content": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredictor = ColaONNXPredictor(\"./models/model.onnx\")\n\n@app.get(\"/\")\nasync def home_page():\n    return \"<h2>Sample prediction API</h2>\"\n\n\n@app.get(\"/predict\")\nasync def get_prediction(text: str):\n    result =  predictor.predict(text)\n    return result"
  },
  {
    "path": "week_8_serverless/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_8_serverless/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_8_serverless/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_8_serverless/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_8_serverless/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_8_serverless/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_8_serverless/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:\n            - \"8000:8000\""
  },
  {
    "path": "week_8_serverless/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_8_serverless/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_8_serverless/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    \"\"\"Score sentences for linguistic acceptability with a trained ColaModel checkpoint.\"\"\"\n\n    def __init__(self, model_path):\n        \"\"\"Load the Lightning checkpoint at ``model_path`` and prepare it for inference.\"\"\"\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        # eval() switches off dropout; freeze() stops gradient tracking on the weights.\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        self.labels = [\"unacceptable\", \"acceptable\"]\n        # Backward-compatible alias for the original, misspelled attribute name.\n        self.lables = self.labels\n\n    @timing\n    def predict(self, text):\n        \"\"\"Return one {\"label\": ..., \"score\": ...} dict per class for ``text``.\"\"\"\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        # Wrap each feature in a list to form a batch of size one.\n        outputs = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        # outputs[0] holds the logits (no labels are passed, so there is no loss);\n        # the trailing [0] selects the single row of the batch.\n        scores = self.softmax(outputs[0]).tolist()[0]\n        return [\n            {\"label\": label, \"score\": score}\n            for score, label in zip(scores, self.labels)\n        ]\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_8_serverless/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    \"\"\"Score sentences for linguistic acceptability with an exported ONNX model.\"\"\"\n\n    def __init__(self, model_path):\n        \"\"\"Open an ONNX Runtime session for the model file at ``model_path``.\"\"\"\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        self.labels = [\"unacceptable\", \"acceptable\"]\n        # Backward-compatible alias for the original, misspelled attribute name.\n        self.lables = self.labels\n\n    @timing\n    def predict(self, text):\n        \"\"\"Return one {\"label\": ..., \"score\": ...} dict per class for ``text``.\n\n        Scores are converted to plain Python floats so the result can be\n        passed straight to ``json.dumps``.\n        \"\"\"\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        # Add a leading batch dimension of size one to every model input.\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        # Normalise over the last axis only. Without ``axis`` scipy's softmax\n        # normalises over every element of the array, which happens to match\n        # per-row softmax only while the batch holds a single sentence.\n        scores = softmax(ort_outs[0], axis=-1)[0]\n        return [\n            {\"label\": label, \"score\": float(score)}\n            for score, label in zip(scores, self.labels)\n        ]\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_8_serverless/lambda_handler.py",
    "content": "\"\"\"\nLambda wrapper around the ONNX CoLA predictor.\n\nTwo invocation styles are supported:\n* API Gateway proxy events (detected by a \"resource\" key): the JSON request\n  body must contain a \"sentence\" field, and an HTTP-style response is returned.\n* Direct invocations: the event itself carries \"sentence\" and the raw list of\n  predictions is returned.\n\"\"\"\n\nimport json\nfrom inference_onnx import ColaONNXPredictor\n\n# Loaded once at import time so every invocation served by this execution\n# environment reuses the same ONNX session instead of reloading the model.\ninferencing_instance = ColaONNXPredictor(\"./models/model.onnx\")\n\n\ndef _http_response(status_code, payload):\n    \"\"\"Build an API Gateway proxy response whose body is ``payload`` as JSON.\"\"\"\n    return {\n        \"statusCode\": status_code,\n        \"headers\": {},\n        \"body\": json.dumps(payload),\n    }\n\n\ndef lambda_handler(event, context):\n    \"\"\"\n    Lambda function handler for predicting linguistic acceptability of the given sentence\n    \"\"\"\n    if \"resource\" in event.keys():\n        try:\n            body = json.loads(event.get(\"body\") or \"\")\n            sentence = body[\"sentence\"]\n        except (ValueError, TypeError, KeyError):\n            # Missing, non-JSON or incomplete request body: answer with a client\n            # error instead of letting the exception surface as a server failure.\n            return _http_response(\n                400, {\"error\": \"request body must be JSON with a 'sentence' field\"}\n            )\n        print(f\"Got the input: {sentence}\")\n        response = inferencing_instance.predict(sentence)\n        return _http_response(200, response)\n    return inferencing_instance.predict(event[\"sentence\"])\n\n\nif __name__ == \"__main__\":\n    test = {\"sentence\": \"this is a sample sentence\"}\n    print(lambda_handler(test, None))\n"
  },
  {
    "path": "week_8_serverless/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        
self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        labels = torch.cat([x[\"labels\"] for x in outputs])\n        logits = torch.cat([x[\"logits\"] for x in outputs])\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. 
Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matric plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_8_serverless/parse_json.py",
    "content": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print(data)\ndata = eval(data)\nprint(data)\n\nwith open('test.json', 'w') as f:\n\tjson.dump(data, f)\n"
  },
  {
    "path": "week_8_serverless/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_8_serverless/requirements_inference.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nuvicorn\ndvc\ntokenizers==0.10.2\ntransformers==4.5.1"
  },
  {
    "path": "week_8_serverless/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        outputs = pl_module(val_batch[\"input_ids\"], val_batch[\"attention_mask\"])\n        preds = torch.argmax(outputs.logits, 1)\n        labels = val_batch[\"label\"]\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        
monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_8_serverless/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        start = time.time()\n        result = f(*args, **kwargs)\n        end = time.time()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  },
  {
    "path": "week_9_monitoring/Dockerfile",
    "content": "FROM amazon/aws-lambda-python\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS_KEY\nARG MODEL_DIR=./models\nRUN mkdir $MODEL_DIR\n\nENV TRANSFORMERS_CACHE=$MODEL_DIR \\\n    TRANSFORMERS_VERBOSITY=error\n\nENV AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \\\n    AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY\n\nRUN yum install git -y && yum -y install gcc-c++\nCOPY requirements_inference.txt requirements_inference.txt\nRUN pip install -r requirements_inference.txt --no-cache-dir\nCOPY ./ ./\nENV PYTHONPATH \"${PYTHONPATH}:./\"\nENV LC_ALL=C.UTF-8\nENV LANG=C.UTF-8\nRUN pip install \"dvc[s3]\"\n# configuring remote server in dvc\nRUN dvc init --no-scm\nRUN dvc remote add -d model-store s3://models-dvc/trained_models/\n\n# pulling the trained model\nRUN dvc pull dvcfiles/trained_model.dvc\nRUN ls\nRUN python lambda_handler.py\nRUN chmod -R 0755 $MODEL_DIR\nCMD [ \"lambda_handler.lambda_handler\"]"
  },
  {
    "path": "week_9_monitoring/README.md",
    "content": "\n**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**\n\n## Requirements:\n\nThis project uses Python 3.8\n\nCreate a virtual env with the following command:\n\n```\nconda create --name project-setup python=3.8\nconda activate project-setup\n```\n\nInstall the requirements:\n\n```\npip install -r requirements.txt\n```\n\n## Running\n\n### Training\n\nAfter installing the requirements, in order to train the model simply run:\n\n```\npython train.py\n```\n\n### Monitoring\n\nOnce the training is completed, at the end of the logs you will see something like:\n\n```\nwandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)\nwandb:\nwandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc\n```\n\nFollow the link to see the wandb dashboard which contains all the plots.\n\n### Versioning data\n\nRefer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)\n\n### Exporting model to ONNX\n\nOnce the model is trained, convert the model using the following command:\n\n```\npython convert_model_to_onnx.py\n```\n\n### Inference\n\n#### Inference using standard pytorch\n\n```\npython inference.py\n```\n\n#### Inference using ONNX Runtime\n\n```\npython inference_onnx.py\n```\n\n## S3 & ECR\n\nFollow the instructions mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating an S3 bucket and an ECR repository. 
\n\n### Configuring dvc\n\n```\ndvc init (this has to be done at root folder)\ndvc remote add -d model-store s3://models-dvc/trained_models/\n```\n\n### AWS credentials\n\nCreate the credentials as mentioned in the [blog post](https://www.ravirajag.dev/blog/mlops-container-registry)\n\n**Do not share the secrets with others**\n\nSet the access key ID and secret access key as environment variables:\n\n```\nexport AWS_ACCESS_KEY_ID=<ACCESS KEY ID>\nexport AWS_SECRET_ACCESS_KEY=<ACCESS SECRET>\n```\n\n### Trained model in DVC\n\nAdd the trained model (ONNX) to DVC using the following command:\n\n```shell\ncd dvcfiles\ndvc add ../models/model.onnx --file trained_model.dvc\n```\n\nPush the model to remote storage:\n\n```shell\ndvc push trained_model.dvc\n```\n\n### Docker\n\nInstall Docker using the [instructions here](https://docs.docker.com/engine/install/)\n\nBuild the image using the command\n\n```shell\ndocker build -t mlops-basics:latest .\n```\n\n**The default command in dockerfile is modified to support the lambda. 
If you want to run without Lambda, use last week's Dockerfile.**\n\nThen run the container using the command\n\n```shell\ndocker run -p 8000:8000 --name inference_container mlops-basics:latest\n```\n\n(or)\n\nBuild and run the container using the command\n\n```shell\ndocker-compose up\n```\n\n### Pushing the image to ECR\n\nFollow the instructions mentioned in [blog post](https://www.ravirajag.dev/blog/mlops-container-registry) for creating an ECR repository.\n\n- Authenticating docker client to ECR\n\n```\naws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 246113150184.dkr.ecr.us-west-2.amazonaws.com\n```\n\n- Tagging the image\n\n```\ndocker tag mlops-basics:latest 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\n- Pushing the image\n\n```\ndocker push 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest\n```\n\nRefer to `.github/workflows/build_docker_image.yaml` file for automatically creating the docker image with trained model and pushing it to ECR.\n\n### Serverless - Lambda\n\nRefer to the [Blog Post here](https://www.ravirajag.dev/blog/mlops-serverless) for detailed instructions on configuring Lambda with the docker image and invoking it using an API.\n\n### Monitoring - Kibana\n\nRefer to the [Blog Post here](https://www.ravirajag.dev/blog/mlops-monitoring) for detailed instructions on configuring Kibana using an Elasticsearch cluster and integrating it with CloudWatch logs.\n\n\n### Running notebooks\n\nI am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.\n\nSince I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.\n\nTo make sure the virtualenv is used, run the following commands before running `jupyter lab`\n\n```\nconda install ipykernel\npython -m ipykernel install --user --name project-setup\npip install ipywidgets\n```"
  },
  {
    "path": "week_9_monitoring/app.py",
    "content": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredictor = ColaONNXPredictor(\"./models/model.onnx\")\n\n@app.get(\"/\")\nasync def home_page():\n    return \"<h2>Sample prediction API</h2>\"\n\n\n@app.get(\"/predict\")\nasync def get_prediction(text: str):\n    result =  predictor.predict(text)\n    return result"
  },
  {
    "path": "week_9_monitoring/configs/config.yaml",
    "content": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - override hydra/hydra_logging: colorlog"
  },
  {
    "path": "week_9_monitoring/configs/model/default.yaml",
    "content": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data"
  },
  {
    "path": "week_9_monitoring/configs/processing/default.yaml",
    "content": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_9_monitoring/configs/training/default.yaml",
    "content": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_train_batches}"
  },
  {
    "path": "week_9_monitoring/convert_model_to_onnx.py",
    "content": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom data import DataModule\n\nlogger = logging.getLogger(__name__)\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef convert_model(cfg):\n    root_dir = hydra.utils.get_original_cwd()\n    model_path = f\"{root_dir}/models/best-checkpoint.ckpt\"\n    logger.info(f\"Loading pre-trained model from: {model_path}\")\n    cola_model = ColaModel.load_from_checkpoint(model_path)\n\n    data_model = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    data_model.prepare_data()\n    data_model.setup()\n    input_batch = next(iter(data_model.train_dataloader()))\n    input_sample = {\n        \"input_ids\": input_batch[\"input_ids\"][0].unsqueeze(0),\n        \"attention_mask\": input_batch[\"attention_mask\"][0].unsqueeze(0),\n    }\n\n    # Export the model\n    logger.info(f\"Converting the model into ONNX format\")\n    torch.onnx.export(\n        cola_model,  # model being run\n        (\n            input_sample[\"input_ids\"],\n            input_sample[\"attention_mask\"],\n        ),  # model input (or a tuple for multiple inputs)\n        f\"{root_dir}/models/model.onnx\",  # where to save the model (can be a file or file-like object)\n        export_params=True,\n        opset_version=10,\n        input_names=[\"input_ids\", \"attention_mask\"],  # the model's input names\n        output_names=[\"output\"],  # the model's output names\n        dynamic_axes={\n            \"input_ids\": {0: \"batch_size\"},  # variable length axes\n            \"attention_mask\": {0: \"batch_size\"},\n            \"output\": {0: \"batch_size\"},\n        },\n    )\n\n    logger.info(\n        f\"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx\"\n    )\n\n\nif __name__ == \"__main__\":\n    convert_model()\n"
  },
  {
    "path": "week_9_monitoring/data.py",
    "content": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        model_name=\"google/bert_uncased_L-2_H-128_A-2\",\n        batch_size=64,\n        max_length=128,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.max_length = max_length\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def prepare_data(self):\n        cola_dataset = load_dataset(\"glue\", \"cola\")\n        self.train_data = cola_dataset[\"train\"]\n        self.val_data = cola_dataset[\"validation\"]\n\n    def tokenize_data(self, example):\n        return self.tokenizer(\n            example[\"sentence\"],\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.max_length,\n        )\n\n    def setup(self, stage=None):\n        # we set up only relevant datasets when stage is specified\n        if stage == \"fit\" or stage is None:\n            self.train_data = self.train_data.map(self.tokenize_data, batched=True)\n            self.train_data.set_format(\n                type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"]\n            )\n\n            self.val_data = self.val_data.map(self.tokenize_data, batched=True)\n            self.val_data.set_format(\n                type=\"torch\",\n                columns=[\"input_ids\", \"attention_mask\", \"label\"],\n                output_all_columns=True,\n            )\n\n    def train_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.train_data, batch_size=self.batch_size, shuffle=True\n        )\n\n    def val_dataloader(self):\n        return torch.utils.data.DataLoader(\n            self.val_data, batch_size=self.batch_size, shuffle=False\n        )\n\n\nif __name__ == \"__main__\":\n    data_model = DataModule()\n 
   data_model.prepare_data()\n    data_model.setup()\n    print(next(iter(data_model.train_dataloader()))[\"input_ids\"].shape)\n"
  },
  {
    "path": "week_9_monitoring/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:\n            - \"8000:8000\""
  },
  {
    "path": "week_9_monitoring/dvcfiles/trained_model.dvc",
    "content": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_9_monitoring/experimental_notebooks/data_exploration.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"cola_dataset = load_dataset('glue', 'cola')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DatasetDict({\\n\",\n       \"    train: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 8551\\n\",\n       \"    })\\n\",\n       \"    validation: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1043\\n\",\n       \"    })\\n\",\n       \"    test: Dataset({\\n\",\n       \"        features: ['sentence', 'label', 'idx'],\\n\",\n       \"        num_rows: 1063\\n\",\n       \"    })\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"cola_dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = 
cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n    \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(8551, 1043, 1063)\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"len(train_dataset), len(val_dataset), len(test_dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': \\\"Our friends won't buy this analysis, let alone the next one we propose.\\\"}\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0,\\n\",\n       \" 'label': 1,\\n\",\n       \" 'sentence': 'The sailors rode the breeze clear of the rocks.'}\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'sentence': Value(dtype='string', id=None),\\n\",\n       \" 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\\n\",\n       \" 'idx': Value(dtype='int32', id=None)}\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c681f26df104422a4c21a216b351949\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [0, 1, 2, 3, 4],\\n\",\n       \" 'label': [1, 1, 1, 1, 1],\\n\",\n       \" 'sentence': [\\\"Our friends won't buy this analysis, let alone the next one we propose.\\\",\\n\",\n       \"  \\\"One more pseudo generalization and I'm giving up.\\\",\\n\",\n       \"  \\\"One more pseudo generalization or I'm giving up.\\\",\\n\",\n       \"  'The more we study verbs, the crazier they get.',\\n\",\n       \"  'Day by day the facts are getting murkier.']}\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7276a21736814e29b7df2af0bdee2dab\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'idx': [18, 20, 22, 23, 25],\\n\",\n       \" 'label': [0, 0, 0, 0, 0],\\n\",\n       \" 'sentence': ['They drank the pub.',\\n\",\n       \"  'The professor talked us.',\\n\",\n       \"  'We yelled ourselves.',\\n\",\n       \"  'We yelled Harry hoarse.',\\n\",\n       \"  'Harry coughed himself.']}\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenizing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google/bert_uncased_L-2_H-128_A-2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset = cola_dataset['train']\\n\",\n    \"val_dataset = cola_dataset['validation']\\n\",\n   
 \"test_dataset = cola_dataset['test']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Our friends won't buy this analysis, let alone the next one we propose.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(train_dataset[0]['sentence'])\\n\",\n    \"tokenizer(train_dataset[0]['sentence'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"\\\"[CLS] our friends won't buy this analysis, let alone the next one we propose. 
[SEP]\\\"\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode(examples):\\n\",\n    \"    return tokenizer(\\n\",\n    \"            examples[\\\"sentence\\\"],\\n\",\n    \"            truncation=True,\\n\",\n    \"            padding=\\\"max_length\\\",\\n\",\n    \"            max_length=512,\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a5205de7df394d5a800f2ee94d3c9106\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_dataset = train_dataset.map(encode, batched=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Formatting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data 
Loader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         ...,\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0],\\n\",\n       \"         [1, 1, 1,  ..., 0, 0, 0]]),\\n\",\n       \" 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2028,  2062,  ...,     0,     0,     0],\\n\",\n       \"         ...,\\n\",\n       \"         [  101,  5965, 12808,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  2198, 10948,  ...,     0,     0,     0],\\n\",\n       \"         [  101,  3021, 24471,  ...,     0,     0,     0]]),\\n\",\n       \" 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\\n\",\n       \"         1, 0, 0, 1, 1, 1, 1, 1])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"next(iter(dataloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) 
torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\\n\",\n      \"torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for batch in dataloader:\\n\",\n    \"    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"day1\",\n   \"language\": \"python\",\n   \"name\": \"day1\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "week_9_monitoring/inference.py",
    "content": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n    \"\"\"Runs a fine-tuned ColaModel Lightning checkpoint on raw sentences.\"\"\"\n\n    def __init__(self, model_path):\n        \"\"\"Load the checkpoint at `model_path` and put it in inference mode.\"\"\"\n        self.model_path = model_path\n        self.model = ColaModel.load_from_checkpoint(model_path)\n        self.model.eval()\n        self.model.freeze()\n        self.processor = DataModule()\n        self.softmax = torch.nn.Softmax(dim=1)\n        # Entry i names output class i (same order as ColaONNXPredictor.labels).\n        self.labels = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        \"\"\"Return a list of {label, score} dicts for `text`, one per class, in label order.\"\"\"\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n        outputs = self.model(\n            torch.tensor([processed[\"input_ids\"]]),\n            torch.tensor([processed[\"attention_mask\"]]),\n        )\n        # No labels are passed, so there is no loss and index 0 is the logits.\n        logits = outputs[0]\n        # Batch of one sentence: keep the single row of class probabilities.\n        scores = self.softmax(logits).tolist()[0]\n        return [\n            {\"label\": label, \"score\": score}\n            for score, label in zip(scores, self.labels)\n        ]\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaPredictor(\"./models/best-checkpoint.ckpt\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_9_monitoring/inference_onnx.py",
    "content": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaONNXPredictor:\n    \"\"\"Runs the exported ONNX CoLA model on raw sentences with onnxruntime.\"\"\"\n\n    def __init__(self, model_path):\n        \"\"\"Open an onnxruntime session for the ONNX file at `model_path`.\"\"\"\n        self.ort_session = ort.InferenceSession(model_path)\n        self.processor = DataModule()\n        # Entry i names output class i of the model.\n        self.labels = [\"unacceptable\", \"acceptable\"]\n\n    @timing\n    def predict(self, text):\n        \"\"\"Classify `text` and return its most likely label.\n\n        Returns a dict {\"text\": text, \"prediction\": {\"label\": ..., \"score\": ...}}\n        where score is the top-class probability rounded to 2 decimals.\n        \"\"\"\n        inference_sample = {\"sentence\": text}\n        processed = self.processor.tokenize_data(inference_sample)\n\n        # Add a leading batch axis of size 1 to each input.\n        ort_inputs = {\n            \"input_ids\": np.expand_dims(processed[\"input_ids\"], axis=0),\n            \"attention_mask\": np.expand_dims(processed[\"attention_mask\"], axis=0),\n        }\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        # Normalise over the class (last) axis only: scipy's default axis=None\n        # would normalise across every row of the batch at once.\n        scores = softmax(ort_outs[0], axis=-1)[0]\n        max_score_id = int(np.argmax(scores))\n        prediction = {\n            \"label\": self.labels[max_score_id],\n            \"score\": round(float(scores[max_score_id]), 2),\n        }\n        return {\"text\": text, \"prediction\": prediction}\n\n\nif __name__ == \"__main__\":\n    sentence = \"The boy is sitting on a bench\"\n    predictor = ColaONNXPredictor(\"./models/model.onnx\")\n    print(predictor.predict(sentence))\n    sentences = [\"The boy is sitting on a bench\"] * 10\n    for sentence in sentences:\n        predictor.predict(sentence)\n"
  },
  {
    "path": "week_9_monitoring/lambda_handler.py",
    "content": "\"\"\"\nLambda wrapper\n\"\"\"\n\nimport base64\nimport json\nimport logging\nfrom inference_onnx import ColaONNXPredictor\n\nlogging.basicConfig()\nlogger = logging.getLogger(__name__)\nlogger.setLevel(level=logging.DEBUG)\n\nlogger.info(\"Loading the model\")\n# Created once at import time so warm Lambda invocations reuse the session.\ninferencing_instance = ColaONNXPredictor(\"./models/model.onnx\")\n\n\ndef _parse_proxy_body(event):\n    \"\"\"Return the JSON-decoded body of an API Gateway proxy event.\"\"\"\n    body = event[\"body\"]\n    # API Gateway sets isBase64Encoded when it forwards the payload base64-encoded.\n    if event.get(\"isBase64Encoded\"):\n        body = base64.b64decode(body).decode(\"utf-8\")\n    return json.loads(body)\n\n\ndef lambda_handler(event, context):\n    \"\"\"\n    Lambda function handler for predicting linguistic acceptability of the given sentence.\n\n    Supports two event shapes:\n      * API Gateway proxy events (detected by a \"resource\" key): the sentence is\n        read from the JSON body and an HTTP-style response dict is returned.\n      * Direct invocation with {\"sentence\": ...}: the prediction dict is returned as-is.\n    \"\"\"\n    if \"resource\" in event:\n        body = _parse_proxy_body(event)\n        logger.info(f\"Got the input: {body['sentence']}\")\n\n        response = inferencing_instance.predict(body[\"sentence\"])\n        logger.info(json.dumps(response))\n        return {\n            \"statusCode\": 200,\n            \"headers\": {},\n            \"body\": json.dumps(response),\n        }\n\n    logger.info(f\"Got the input: {event['sentence']}\")\n    response = inferencing_instance.predict(event[\"sentence\"])\n    logger.info(json.dumps(response))\n    return response\n\n\nif __name__ == \"__main__\":\n    test = {\"sentence\": \"this is a sample sentence\"}\n    lambda_handler(test, None)\n"
  },
  {
    "path": "week_9_monitoring/model.py",
    "content": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightning as pl\nfrom transformers import AutoModelForSequenceClassification\nfrom omegaconf import OmegaConf, DictConfig\nfrom sklearn.metrics import confusion_matrix\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n\nclass ColaModel(pl.LightningModule):\n    \"\"\"Binary CoLA acceptability classifier wrapping a Hugging Face model.\"\"\"\n\n    def __init__(self, model_name=\"google/bert_uncased_L-2_H-128_A-2\", lr=3e-5):\n        super(ColaModel, self).__init__()\n        # Stores model_name and lr in self.hparams; lr is read by configure_optimizers.\n        self.save_hyperparameters()\n\n        self.bert = AutoModelForSequenceClassification.from_pretrained(\n            model_name, num_labels=2\n        )\n        self.num_classes = 2\n        self.train_accuracy_metric = torchmetrics.Accuracy()\n        self.val_accuracy_metric = torchmetrics.Accuracy()\n        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)\n        self.precision_macro_metric = torchmetrics.Precision(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.recall_macro_metric = torchmetrics.Recall(\n            average=\"macro\", num_classes=self.num_classes\n        )\n        self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n        self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n    def forward(self, input_ids, attention_mask, labels=None):\n        \"\"\"Run the wrapped classifier and return its raw output object.\"\"\"\n        outputs = self.bert(\n            input_ids=input_ids, attention_mask=attention_mask, labels=labels\n        )\n        return outputs\n\n    def training_step(self, batch, batch_idx):\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"]\n        )\n        # loss = F.cross_entropy(logits, batch[\"label\"])\n        preds = torch.argmax(outputs.logits, 1)\n        train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n        self.log(\"train/loss\", outputs.loss, prog_bar=True, on_epoch=True)\n        self.log(\"train/acc\", train_acc, prog_bar=True, on_epoch=True)\n        return outputs.loss\n\n    def validation_step(self, batch, batch_idx):\n        labels = batch[\"label\"]\n        outputs = self.forward(\n            batch[\"input_ids\"], batch[\"attention_mask\"], labels=labels\n        )\n        preds = torch.argmax(outputs.logits, 1)\n\n        # Metrics\n        valid_acc = self.val_accuracy_metric(preds, labels)\n        precision_macro = self.precision_macro_metric(preds, labels)\n        recall_macro = self.recall_macro_metric(preds, labels)\n        precision_micro = self.precision_micro_metric(preds, labels)\n        recall_micro = self.recall_micro_metric(preds, labels)\n        f1 = self.f1_metric(preds, labels)\n\n        # Logging metrics\n        self.log(\"valid/loss\", outputs.loss, prog_bar=True, on_step=True)\n        self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True)\n        self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n        return {\"labels\": labels, \"logits\": outputs.logits}\n\n    def validation_epoch_end(self, outputs):\n        \"\"\"Log a confusion matrix computed over the whole validation epoch.\"\"\"\n        # .numpy() raises on CUDA tensors, so gather the epoch's results on the CPU.\n        labels = torch.cat([x[\"labels\"] for x in outputs]).detach().cpu()\n        logits = torch.cat([x[\"logits\"] for x in outputs]).detach().cpu()\n        preds = torch.argmax(logits, 1)\n\n        ## There are multiple ways to track the metrics\n        # 1. Confusion matrix plotting using inbuilt W&B method\n        self.logger.experiment.log(\n            {\n                \"conf\": wandb.plot.confusion_matrix(\n                    probs=logits.numpy(), y_true=labels.numpy()\n                )\n            }\n        )\n\n        # 2. Confusion Matrix plotting using scikit-learn method\n        # wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})\n\n        # 3. Confusion Matrix plotting using Seaborn\n        # data = confusion_matrix(labels.numpy(), preds.numpy())\n        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))\n        # df_cm.index.name = \"Actual\"\n        # df_cm.columns.name = \"Predicted\"\n        # plt.figure(figsize=(7, 4))\n        # plot = sns.heatmap(\n        #     df_cm, cmap=\"Blues\", annot=True, annot_kws={\"size\": 16}\n        # )  # font size\n        # self.logger.experiment.log({\"Confusion Matrix\": wandb.Image(plot)})\n\n        # self.logger.experiment.log(\n        #     {\"roc\": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}\n        # )\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n"
  },
  {
    "path": "week_9_monitoring/parse_json.py",
    "content": "import ast\nimport json\n\n# Convert creds.txt into a strict JSON file (test.json).\n# The text is parsed as JSON first (strict=False tolerates raw control\n# characters such as literal newlines inside strings) and, failing that, as a\n# Python literal (e.g. a dict repr with single quotes). eval() is deliberately\n# not used: it would execute any code contained in the file.\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\ntry:\n\tdata = json.loads(data, strict=False)\nexcept json.JSONDecodeError:\n\t# literal_eval accepts only literals (str, bytes, numbers, dict, list, ...).\n\tdata = ast.literal_eval(data)\n\n# Do not echo secret values into logs; only report which fields were parsed.\nprint(list(data) if isinstance(data, dict) else type(data).__name__)\n\nwith open('test.json', 'w') as f:\n\tjson.dump(data, f)\n"
  },
  {
    "path": "week_9_monitoring/requirements.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn\nhydra-core\nomegaconf\nhydra_colorlog\nfastapi\nuvicorn\n"
  },
  {
    "path": "week_9_monitoring/requirements_inference.txt",
    "content": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nuvicorn\ndvc\ntokenizers==0.10.2\ntransformers==4.5.1"
  },
  {
    "path": "week_9_monitoring/train.py",
    "content": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf.omegaconf import OmegaConf\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom data import DataModule\nfrom model import ColaModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplesVisualisationLogger(pl.Callback):\n    \"\"\"Logs misclassified sentences from the first validation batch to W&B.\"\"\"\n\n    def __init__(self, datamodule):\n        super().__init__()\n\n        self.datamodule = datamodule\n\n    def on_validation_end(self, trainer, pl_module):\n        val_batch = next(iter(self.datamodule.val_dataloader()))\n        sentences = val_batch[\"sentence\"]\n\n        # Predict on the model's device, in eval mode and without building an\n        # autograd graph; the module's previous train/eval mode is restored.\n        was_training = pl_module.training\n        pl_module.eval()\n        with torch.no_grad():\n            outputs = pl_module(\n                val_batch[\"input_ids\"].to(pl_module.device),\n                val_batch[\"attention_mask\"].to(pl_module.device),\n            )\n        pl_module.train(was_training)\n        # .numpy() only works on CPU tensors.\n        preds = torch.argmax(outputs.logits, 1).cpu()\n        labels = val_batch[\"label\"].cpu()\n\n        df = pd.DataFrame(\n            {\"Sentence\": sentences, \"Label\": labels.numpy(), \"Predicted\": preds.numpy()}\n        )\n\n        wrong_df = df[df[\"Label\"] != df[\"Predicted\"]]\n        trainer.logger.experiment.log(\n            {\n                \"examples\": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),\n                \"global_step\": trainer.global_step,\n            }\n        )\n\n\n@hydra.main(config_path=\"./configs\", config_name=\"config\")\ndef main(cfg):\n    logger.info(OmegaConf.to_yaml(cfg, resolve=True))\n    logger.info(f\"Using the model: {cfg.model.name}\")\n    logger.info(f\"Using the tokenizer: {cfg.model.tokenizer}\")\n    cola_data = DataModule(\n        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length\n    )\n    cola_model = ColaModel(cfg.model.name)\n\n    root_dir = hydra.utils.get_original_cwd()\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=f\"{root_dir}/models\",\n        filename=\"best-checkpoint\",\n        monitor=\"valid/loss\",\n        mode=\"min\",\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n    )\n\n    wandb_logger = WandbLogger(project=\"MLOps Basics\", entity=\"raviraja\")\n    trainer = pl.Trainer(\n        max_epochs=cfg.training.max_epochs,\n        logger=wandb_logger,\n        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],\n        log_every_n_steps=cfg.training.log_every_n_steps,\n        deterministic=cfg.training.deterministic,\n        # limit_train_batches=cfg.training.limit_train_batches,\n        # limit_val_batches=cfg.training.limit_val_batches,\n    )\n    trainer.fit(cola_model, cola_data)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "week_9_monitoring/utils.py",
    "content": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator that prints how long each call to the wrapped function takes.\n\n    Usage:\n    @timing\n    def function(a):\n        pass\n    \"\"\"\n\n    @wraps(f)\n    def wrapper(*args, **kwargs):\n        # perf_counter is monotonic and high-resolution; time.time() follows the\n        # wall clock and can jump if the system clock is adjusted.\n        start = time.perf_counter()\n        result = f(*args, **kwargs)\n        end = time.perf_counter()\n        print(\"function:%r took: %2.5f sec\" % (f.__name__, end - start))\n        return result\n\n    return wrapper\n"
  }
]