Full Code of graviraja/MLOps-Basics for AI

main 558adce8203a cached
167 files
502.1 KB
162.0k tokens
248 symbols
1 requests
Download .txt
Showing preview only (543K chars total). Download the full file or copy to clipboard to get everything.
Repository: graviraja/MLOps-Basics
Branch: main
Commit: 558adce8203a
Files: 167
Total size: 502.1 KB

Directory structure:
gitextract_g0pkv867/

├── .dvc/
│   ├── .gitignore
│   ├── config
│   └── plots/
│       ├── confusion.json
│       ├── confusion_normalized.json
│       ├── default.json
│       ├── linear.json
│       ├── scatter.json
│       └── smooth.json
├── .dvcignore
├── .github/
│   └── workflows/
│       ├── basic.yaml
│       └── build_docker_image.yaml
├── .gitignore
├── LICENSE
├── README.md
├── week_0_project_setup/
│   ├── README.md
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_1_wandb_logging/
│   ├── README.md
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_2_hydra_config/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_3_dvc/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── data.py
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_4_onnx/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── requirements.txt
│   ├── train.py
│   └── utils.py
├── week_5_docker/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_6_github_actions/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_7_ecr/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_8_serverless/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── lambda_handler.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
└── week_9_monitoring/
    ├── Dockerfile
    ├── README.md
    ├── app.py
    ├── configs/
    │   ├── config.yaml
    │   ├── model/
    │   │   └── default.yaml
    │   ├── processing/
    │   │   └── default.yaml
    │   └── training/
    │       └── default.yaml
    ├── convert_model_to_onnx.py
    ├── data.py
    ├── docker-compose.yml
    ├── dvcfiles/
    │   └── trained_model.dvc
    ├── experimental_notebooks/
    │   └── data_exploration.ipynb
    ├── inference.py
    ├── inference_onnx.py
    ├── lambda_handler.py
    ├── model.py
    ├── parse_json.py
    ├── requirements.txt
    ├── requirements_inference.txt
    ├── train.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .dvc/.gitignore
================================================
/config.local
/tmp
/cache


================================================
FILE: .dvc/config
================================================
[core]
    remote = model-store
['remote "storage"']
    url = gdrive://19JK5AFbqOBlrFVwDHjTrf9uvQFtS0954
['remote "model-store"']
    url = s3://models-dvc/trained_models/


================================================
FILE: .dvc/plots/confusion.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "facet": {
        "field": "rev",
        "type": "nominal"
    },
    "spec": {
        "transform": [
            {
                "aggregate": [
                    {
                        "op": "count",
                        "as": "xy_count"
                    }
                ],
                "groupby": [
                    "<DVC_METRIC_Y>",
                    "<DVC_METRIC_X>"
                ]
            },
            {
                "impute": "xy_count",
                "groupby": [
                    "rev",
                    "<DVC_METRIC_Y>"
                ],
                "key": "<DVC_METRIC_X>",
                "value": 0
            },
            {
                "impute": "xy_count",
                "groupby": [
                    "rev",
                    "<DVC_METRIC_X>"
                ],
                "key": "<DVC_METRIC_Y>",
                "value": 0
            },
            {
                "joinaggregate": [
                    {
                        "op": "max",
                        "field": "xy_count",
                        "as": "max_count"
                    }
                ],
                "groupby": []
            },
            {
                "calculate": "datum.xy_count / datum.max_count",
                "as": "percent_of_max"
            }
        ],
        "encoding": {
            "x": {
                "field": "<DVC_METRIC_X>",
                "type": "nominal",
                "sort": "ascending",
                "title": "<DVC_METRIC_X_LABEL>"
            },
            "y": {
                "field": "<DVC_METRIC_Y>",
                "type": "nominal",
                "sort": "ascending",
                "title": "<DVC_METRIC_Y_LABEL>"
            }
        },
        "layer": [
            {
                "mark": "rect",
                "width": 300,
                "height": 300,
                "encoding": {
                    "color": {
                        "field": "xy_count",
                        "type": "quantitative",
                        "title": "",
                        "scale": {
                            "domainMin": 0,
                            "nice": true
                        }
                    }
                }
            },
            {
                "mark": "text",
                "encoding": {
                    "text": {
                        "field": "xy_count",
                        "type": "quantitative"
                    },
                    "color": {
                        "condition": {
                            "test": "datum.percent_of_max > 0.5",
                            "value": "white"
                        },
                        "value": "black"
                    }
                }
            }
        ]
    }
}


================================================
FILE: .dvc/plots/confusion_normalized.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "facet": {
        "field": "rev",
        "type": "nominal"
    },
    "spec": {
        "transform": [
            {
                "aggregate": [
                    {
                        "op": "count",
                        "as": "xy_count"
                    }
                ],
                "groupby": [
                    "<DVC_METRIC_Y>",
                    "<DVC_METRIC_X>"
                ]
            },
            {
                "impute": "xy_count",
                "groupby": [
                    "rev",
                    "<DVC_METRIC_Y>"
                ],
                "key": "<DVC_METRIC_X>",
                "value": 0
            },
            {
                "impute": "xy_count",
                "groupby": [
                    "rev",
                    "<DVC_METRIC_X>"
                ],
                "key": "<DVC_METRIC_Y>",
                "value": 0
            },
            {
                "joinaggregate": [
                    {
                        "op": "sum",
                        "field": "xy_count",
                        "as": "sum_y"
                    }
                ],
                "groupby": [
                    "<DVC_METRIC_Y>"
                ]
            },
            {
                "calculate": "datum.xy_count / datum.sum_y",
                "as": "percent_of_y"
            }
        ],
        "encoding": {
            "x": {
                "field": "<DVC_METRIC_X>",
                "type": "nominal",
                "sort": "ascending",
                "title": "<DVC_METRIC_X_LABEL>"
            },
            "y": {
                "field": "<DVC_METRIC_Y>",
                "type": "nominal",
                "sort": "ascending",
                "title": "<DVC_METRIC_Y_LABEL>"
            }
        },
        "layer": [
            {
                "mark": "rect",
                "width": 300,
                "height": 300,
                "encoding": {
                    "color": {
                        "field": "percent_of_y",
                        "type": "quantitative",
                        "title": "",
                        "scale": {
                            "domain": [
                                0,
                                1
                            ]
                        }
                    }
                }
            },
            {
                "mark": "text",
                "encoding": {
                    "text": {
                        "field": "percent_of_y",
                        "type": "quantitative",
                        "format": ".2f"
                    },
                    "color": {
                        "condition": {
                            "test": "datum.percent_of_y > 0.5",
                            "value": "white"
                        },
                        "value": "black"
                    }
                }
            }
        ]
    }
}


================================================
FILE: .dvc/plots/default.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "width": 300,
    "height": 300,
    "mark": {
        "type": "line"
    },
    "encoding": {
        "x": {
            "field": "<DVC_METRIC_X>",
            "type": "quantitative",
            "title": "<DVC_METRIC_X_LABEL>"
        },
        "y": {
            "field": "<DVC_METRIC_Y>",
            "type": "quantitative",
            "title": "<DVC_METRIC_Y_LABEL>",
            "scale": {
                "zero": false
            }
        },
        "color": {
            "field": "rev",
            "type": "nominal"
        }
    }
}


================================================
FILE: .dvc/plots/linear.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "width": 300,
    "height": 300,
    "layer": [
        {
            "encoding": {
                "x": {
                    "field": "<DVC_METRIC_X>",
                    "type": "quantitative",
                    "title": "<DVC_METRIC_X_LABEL>"
                },
                "y": {
                    "field": "<DVC_METRIC_Y>",
                    "type": "quantitative",
                    "title": "<DVC_METRIC_Y_LABEL>",
                    "scale": {
                        "zero": false
                    }
                },
                "color": {
                    "field": "rev",
                    "type": "nominal"
                }
            },
            "layer": [
                {
                    "mark": "line"
                },
                {
                    "selection": {
                        "label": {
                            "type": "single",
                            "nearest": true,
                            "on": "mouseover",
                            "encodings": [
                                "x"
                            ],
                            "empty": "none",
                            "clear": "mouseout"
                        }
                    },
                    "mark": "point",
                    "encoding": {
                        "opacity": {
                            "condition": {
                                "selection": "label",
                                "value": 1
                            },
                            "value": 0
                        }
                    }
                }
            ]
        },
        {
            "transform": [
                {
                    "filter": {
                        "selection": "label"
                    }
                }
            ],
            "layer": [
                {
                    "mark": {
                        "type": "rule",
                        "color": "gray"
                    },
                    "encoding": {
                        "x": {
                            "field": "<DVC_METRIC_X>",
                            "type": "quantitative"
                        }
                    }
                },
                {
                    "encoding": {
                        "text": {
                            "type": "quantitative",
                            "field": "<DVC_METRIC_Y>"
                        },
                        "x": {
                            "field": "<DVC_METRIC_X>",
                            "type": "quantitative"
                        },
                        "y": {
                            "field": "<DVC_METRIC_Y>",
                            "type": "quantitative"
                        }
                    },
                    "layer": [
                        {
                            "mark": {
                                "type": "text",
                                "align": "left",
                                "dx": 5,
                                "dy": -5
                            },
                            "encoding": {
                                "color": {
                                    "type": "nominal",
                                    "field": "rev"
                                }
                            }
                        }
                    ]
                }
            ]
        }
    ]
}


================================================
FILE: .dvc/plots/scatter.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "width": 300,
    "height": 300,
    "layer": [
        {
            "encoding": {
                "x": {
                    "field": "<DVC_METRIC_X>",
                    "type": "quantitative",
                    "title": "<DVC_METRIC_X_LABEL>"
                },
                "y": {
                    "field": "<DVC_METRIC_Y>",
                    "type": "quantitative",
                    "title": "<DVC_METRIC_Y_LABEL>",
                    "scale": {
                        "zero": false
                    }
                },
                "color": {
                    "field": "rev",
                    "type": "nominal"
                }
            },
            "layer": [
                {
                    "mark": "point"
                },
                {
                    "selection": {
                        "label": {
                            "type": "single",
                            "nearest": true,
                            "on": "mouseover",
                            "encodings": [
                                "x"
                            ],
                            "empty": "none",
                            "clear": "mouseout"
                        }
                    },
                    "mark": "point",
                    "encoding": {
                        "opacity": {
                            "condition": {
                                "selection": "label",
                                "value": 1
                            },
                            "value": 0
                        }
                    }
                }
            ]
        },
        {
            "transform": [
                {
                    "filter": {
                        "selection": "label"
                    }
                }
            ],
            "layer": [
                {
                    "encoding": {
                        "text": {
                            "type": "quantitative",
                            "field": "<DVC_METRIC_Y>"
                        },
                        "x": {
                            "field": "<DVC_METRIC_X>",
                            "type": "quantitative"
                        },
                        "y": {
                            "field": "<DVC_METRIC_Y>",
                            "type": "quantitative"
                        }
                    },
                    "layer": [
                        {
                            "mark": {
                                "type": "text",
                                "align": "left",
                                "dx": 5,
                                "dy": -5
                            },
                            "encoding": {
                                "color": {
                                    "type": "nominal",
                                    "field": "rev"
                                }
                            }
                        }
                    ]
                }
            ]
        }
    ]
}


================================================
FILE: .dvc/plots/smooth.json
================================================
{
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {
        "values": "<DVC_METRIC_DATA>"
    },
    "title": "<DVC_METRIC_TITLE>",
    "mark": {
        "type": "line"
    },
    "encoding": {
        "x": {
            "field": "<DVC_METRIC_X>",
            "type": "quantitative",
            "title": "<DVC_METRIC_X_LABEL>"
        },
        "y": {
            "field": "<DVC_METRIC_Y>",
            "type": "quantitative",
            "title": "<DVC_METRIC_Y_LABEL>",
            "scale": {
                "zero": false
            }
        },
        "color": {
            "field": "rev",
            "type": "nominal"
        }
    },
    "transform": [
        {
            "loess": "<DVC_METRIC_Y>",
            "on": "<DVC_METRIC_X>",
            "groupby": [
                "rev"
            ],
            "bandwidth": 0.3
        }
    ]
}


================================================
FILE: .dvcignore
================================================
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore


================================================
FILE: .github/workflows/basic.yaml
================================================
name: GitHub Actions Basic Flow
on: [push]
jobs:
  Basic-workflow:
    runs-on: ubuntu-latest
    steps:
      - name: Basic Information
        run: |
          echo "🎬 The job was automatically triggered by a ${{ github.event_name }} event."
          echo "💻 This job is now running on a ${{ runner.os }} server hosted by GitHub!"
          echo "🎋 Workflow is running on the branch ${{ github.ref }}"
      - name: Checking out the repository
        uses: actions/checkout@v2
      - name: Information after checking out
        run: |
          echo "💡 The ${{ github.repository }} repository has been cloned to the runner."
          echo "🖥️ The workflow is now ready to test your code on the runner."
      - name: List files in the repository
        run: |
          ls ${{ github.workspace }}
      - run: echo "🍏 This job's status is ${{ job.status }}."

================================================
FILE: .github/workflows/build_docker_image.yaml
================================================
name: Create Docker Container

on: [push]

jobs:
  mlops-container:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ./week_9_monitoring
    steps:
    - name: Checkout
      uses: actions/checkout@v2
      with:
        ref: ${{ github.ref }}
    - name: Configure AWS Credentials
      uses: aws-actions/configure-aws-credentials@v1
      with:
        aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
        aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        aws-region: us-west-2
    - name: Build container
      run: |
        docker build --build-arg AWS_ACCOUNT_ID=${{ secrets.AWS_ACCOUNT_ID }} \
                     --build-arg AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }} \
                     --build-arg AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }} \
                     --tag mlops-basics .
    - name: Push2ECR
      id: ecr
      uses: jwalton/gh-ecr-push@v1
      with:
        access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
        secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        region: us-west-2
        image: mlops-basics:latest
    
    - name: Update lambda with image
      run: aws lambda update-function-code --function-name  MLOps-Basics --image-uri 246113150184.dkr.ecr.us-west-2.amazonaws.com/mlops-basics:latest


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.vscode/
*/logs/*
*/models/*
*/wandb/*
*/outputs/*
*/multirun/*

.DS_Store
*/.DS_Store

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 raviraja

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# MLOps-Basics

 > There is nothing magic about magic. The magician merely understands something simple which doesn’t appear to be simple or natural to the untrained audience. Once you learn how to hold a card while making your hand look empty, you only need practice before you, too, can “do magic.” – Jeffrey Friedl in the book Mastering Regular Expressions

**Note: Please raise an issue for any suggestions, corrections, and feedback.**

The goal of the series is to understand the basics of MLOps like model building, monitoring, configurations, testing, packaging, deployment, CI/CD, etc.

![pl](images/summary.png)

## Week 0: Project Setup

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-project-setup-part1)

The project I have implemented is a simple classification problem. The scope of this week is to understand the following topics:

- `How to get the data?`
- `How to process the data?`
- `How to define dataloaders?`
- `How to declare the model?`
- `How to train the model?`
- `How to do the inference?`

![pl](images/pl.jpeg)

Following tech stack is used:

- [Huggingface Datasets](https://github.com/huggingface/datasets)
- [Huggingface Transformers](https://github.com/huggingface/transformers)
- [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/)

## Week 1: Model monitoring - Weights and Biases

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-wandb-integration)

Tracking all experiments (tweaking hyper-parameters, trying different models to compare their performance, and seeing the connection between the model and the input data) helps in developing a better model.

The scope of this week is to understand the following topics:

- `How to configure basic logging with W&B?`
- `How to compute metrics and log them in W&B?`
- `How to add plots in W&B?`
- `How to add data samples to W&B?`

![wannb](images/wandb.png)

Following tech stack is used:

- [Weights and Biases](https://wandb.ai/site)
- [torchmetrics](https://torchmetrics.readthedocs.io/)

References:

- [Tutorial on Pytorch Lightning + Weights & Bias](https://www.youtube.com/watch?v=hUXQm46TAKc)

- [WandB Documentation](https://docs.wandb.ai/)

## Week 2: Configurations - Hydra

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-hydra-config)

Configuration management is necessary for managing complex software systems. Lack of configuration management can cause serious problems with reliability, uptime, and the ability to scale a system.

The scope of this week is to understand the following topics:

- `Basics of Hydra`
- `Overriding configurations`
- `Splitting configuration across multiple files`
- `Variable Interpolation`
- `How to run model with different parameter combinations?`

![hydra](images/hydra.png)

Following tech stack is used:

- [Hydra](https://hydra.cc/)

References

- [Hydra Documentation](https://hydra.cc/docs/intro)

- [Simone Scardapane's Tutorial on Hydra](https://www.sscardapane.it/tutorials/hydra-tutorial/#executing-multiple-runs)


## Week 3: Data Version Control - DVC

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-dvc)

Classical code version control systems are not designed to handle large files, which are very common in Machine Learning; storing such files in them makes cloning and keeping the history impractical.

The scope of this week is to understand the following topics:

- `Basics of DVC`
- `Initialising DVC`
- `Configuring Remote Storage`
- `Saving Model to the Remote Storage`
- `Versioning the models`

![dvc](images/dvc.png)

Following tech stack is used:

- [DVC](https://dvc.org/)

References

- [DVC Documentation](https://dvc.org/doc)

- [DVC Tutorial on Versioning data](https://www.youtube.com/watch?v=kLKBcPonMYw)

## Week 4: Model Packaging - ONNX

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-onnx)

Why do we need model packaging? Models can be built using any machine learning framework available out there (sklearn, tensorflow, pytorch, etc.). We might want to deploy models in different environments (mobile, web, Raspberry Pi) or run them in a different framework (trained in pytorch, inference in tensorflow).
A common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers will help a lot.

This is achieved by a community project, `ONNX`.

The scope of this week is to understand the following topics:

- `What is ONNX?`

- `How to convert a trained model to ONNX format?`

- `What is ONNX Runtime?`

- `How to run ONNX converted model in ONNX Runtime?`

- `Comparisons`

![ONNX](images/onnx.jpeg)

Following tech stack is used:

- [ONNX](https://onnx.ai/)
- [ONNXRuntime](https://www.onnxruntime.ai/)

References

- [Abhishek Thakur tutorial on onnx model conversion](https://www.youtube.com/watch?v=7nutT3Aacyw)
- [Pytorch Lightning documentation on onnx conversion](https://pytorch-lightning.readthedocs.io/en/stable/common/production_inference.html)
- [Huggingface Blog on ONNXRuntime](https://medium.com/microsoftazure/accelerate-your-nlp-pipelines-using-hugging-face-transformers-and-onnx-runtime-2443578f4333)
- [Piotr Blog on onnx conversion](https://tugot17.github.io/data-science-blog/onnx/tutorial/2020/09/21/Exporting-lightning-model-to-onnx.html)


## Week 5: Model Packaging - Docker

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=easy&color=green"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-docker)

Why do we need packaging? We might have to share our application with others, and when they try to run it, most of the time it doesn’t run due to dependency issues / OS-related issues. This is where the famous quote across engineers comes from: `It works on my laptop/system`.

So for others to run the applications they have to set up the same environment as it was run on the host side which means a lot of manual configuration and installation of components.

The solution to these limitations is a technology called Containers.

By containerizing/packaging the application, we can run it on any cloud platform and take advantage of managed services, autoscaling, reliability, and much more.

The most prominent tool to do the packaging of application is Docker 🛳

The scope of this week is to understand the following topics:

- `FastAPI wrapper`
- `Basics of Docker`
- `Building Docker Container`
- `Docker Compose`

![Docker](images/docker_flow.png)

References

- [Analytics vidhya blog](https://www.analyticsvidhya.com/blog/2021/06/a-hands-on-guide-to-containerized-your-machine-learning-workflow-with-docker/)


## Week 6: CI/CD - GitHub Actions

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-github-actions)

CI/CD is a coding philosophy and set of practices with which you can continuously build, test, and deploy iterative code changes.

This iterative process helps reduce the chance that you develop new code based on a buggy or failed previous version. With this method, you strive to have less human intervention, or even none at all, from the development of new code until its deployment.

In this post, I will be going through the following topics:

- Basics of GitHub Actions
- First GitHub Action
- Creating Google Service Account
- Giving access to Service account
- Configuring DVC to use Google Service account
- Configuring Github Action

![GitHub Actions](images/basic_flow.png)

References

- [Configuring service account](https://dvc.org/doc/user-guide/setup-google-drive-remote)

- [Github actions](https://docs.github.com/en/actions/quickstart)


## Week 7: Container Registry - AWS ECR

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-container-registry)

A container registry is a place to store container images. A container image is a file comprised of multiple layers which can execute applications in a single instance. Hosting all the images in one stored location allows users to commit, identify and pull images when needed.

Amazon Simple Storage Service (S3) is a storage service for the internet. It is designed for large-capacity, low-cost storage provision across multiple geographical regions.

In this week, I will be going through the following topics:

- `Basics of S3`

- `Programmatic access to S3`

- `Configuring AWS S3 as remote storage in DVC`

- `Basics of ECR`

- `Configuring GitHub Actions to use S3, ECR`

![ECR](images/ecr_flow.png)


## Week 8: Serverless Deployment - AWS Lambda

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-serverless)

A serverless architecture is a way to build and run applications and services without having to manage infrastructure. The application still runs on servers, but all the server management is done by third party service (AWS). We no longer have to provision, scale, and maintain servers to run the applications. By using a serverless architecture, developers can focus on their core product instead of worrying about managing and operating servers or runtimes, either in the cloud or on-premises.

In this week, I will be going through the following topics:

- `Basics of Serverless`

- `Basics of AWS Lambda`

- `Triggering Lambda with API Gateway`

- `Deploying Container using Lambda`

- `Automating deployment to Lambda using Github Actions`

![Lambda](images/lambda_flow.png)


## Week 9: Prediction Monitoring - Kibana

<img src="https://img.shields.io/static/v1.svg?style=for-the-badge&label=difficulty&message=medium&color=orange"/>

Refer to the [Blog Post here](https://deep-learning-blogs.vercel.app/blog/mlops-monitoring)


Monitoring systems can help give us confidence that our systems are running smoothly and, in the event of a system failure, can quickly provide appropriate context when diagnosing the root cause.

Things we want to monitor during training and inference are different. During training, we are concerned about whether the loss is decreasing or not, whether the model is overfitting, etc.

But during inference, we would like to have confidence that our model is making correct predictions.

There are many reasons why a model can fail to make useful predictions:

- The underlying data distribution has shifted over time and the model has gone stale, i.e., the characteristics of the inference data differ from those of the data used to train the model.

- The inference data stream contains edge cases (not seen during model training). In these scenarios, the model might perform poorly or fail with errors.

- The model was misconfigured in its production deployment. (Configuration issues are common)

In all of these scenarios, the model could still make a `successful` prediction from a service perspective, but the predictions will likely not be useful. Monitoring machine learning models can help us detect such scenarios and intervene (e.g. trigger a model retraining/deployment pipeline).

In this week, I will be going through the following topics:

- `Basics of Cloudwatch Logs`

- `Creating Elastic Search Cluster`

- `Configuring Cloudwatch Logs with Elastic Search`

- `Creating Index Patterns in Kibana`

- `Creating Kibana Visualisations`

- `Creating Kibana Dashboard`

![Kibana](images/kibana_flow.png)


================================================
FILE: week_0_project_setup/README.md
================================================

**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**

## Requirements:

This project uses Python 3.8

Create a virtual env with the following command:

```
conda create --name project-setup python=3.8
conda activate project-setup
```

Install the requirements:

```
pip install -r requirements.txt
```

## Running

### Training

After installing the requirements, in order to train the model simply run:

```
python train.py
```

### Inference

After training, update the model checkpoint path in the code and run

```
python inference.py
```

### Running notebooks

I am using [Jupyter lab](https://jupyter.org/install) to run the notebooks. 

Since I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.

To make sure to use the virtualenv, run the following commands before running `jupyter lab`

```
conda install ipykernel
python -m ipykernel install --user --name project-setup
pip install ipywidgets
```




================================================
FILE: week_0_project_setup/data.py
================================================
import torch
import datasets
import pytorch_lightning as pl

from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    """LightningDataModule serving the GLUE CoLA train/validation splits.

    Sentences are tokenized with a Hugging Face tokenizer, padded/truncated
    to ``max_length`` tokens, and exposed as torch tensors restricted to the
    columns ``input_ids``, ``attention_mask`` and ``label``.

    Args:
        model_name: Hugging Face model id whose tokenizer is loaded.
        batch_size: Batch size used by both dataloaders.
        max_length: Length every sequence is padded/truncated to
            (defaults to 512, the previously hard-coded value).
    """

    # Columns kept (and converted to torch tensors) after tokenization.
    _COLUMNS = ["input_ids", "attention_mask", "label"]

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=32,
        max_length=512,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Populated by setup(); kept as public attributes so callers that read
        # them after setup() keep working.
        self.train_data = None
        self.val_data = None

    def prepare_data(self):
        """Download the CoLA dataset into the local Hugging Face cache.

        Lightning calls this hook on a single process only, so it must not
        assign state on ``self``. The splits are (re)loaded in setup(), which
        runs on every process.
        """
        load_dataset("glue", "cola")

    def tokenize_data(self, example):
        """Tokenize a (batched) example dict that has a ``sentence`` field."""
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def _prepare_split(self, split):
        """Tokenize one raw split and format it as torch tensors."""
        split = split.map(self.tokenize_data, batched=True)
        split.set_format(type="torch", columns=self._COLUMNS)
        return split

    def setup(self, stage=None):
        """Load, tokenize and format the splits needed for ``stage``.

        Raw splits are re-read on every call, so calling setup() more than
        once never re-maps data that has already been tokenized and
        torch-formatted. The validation split is also built for the
        "validate" stage used by ``Trainer.validate``.
        """
        cola_dataset = load_dataset("glue", "cola")
        if stage in ("fit", None):
            self.train_data = self._prepare_split(cola_dataset["train"])
        if stage in ("fit", "validate", None):
            self.val_data = self._prepare_split(cola_dataset["validation"])

    def train_dataloader(self):
        """Shuffled DataLoader over the tokenized training split."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) DataLoader over the validation split."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )


if __name__ == "__main__":
    # Smoke check: build the datamodule end to end and report the shape of
    # the first training batch's token ids.
    module = DataModule()
    module.prepare_data()
    module.setup()
    first_batch = next(iter(module.train_dataloader()))
    print(first_batch["input_ids"].shape)


================================================
FILE: week_0_project_setup/experimental_notebooks/data_exploration.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    }
   ],
   "source": [
    "cola_dataset = load_dataset('glue', 'cola')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 8551\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1043\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1063\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cola_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8551, 1043, 1063)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset), len(val_dataset), len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': 'The sailors rode the breeze clear of the rocks.'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\n",
       " 'idx': Value(dtype='int32', id=None)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c681f26df104422a4c21a216b351949",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [0, 1, 2, 3, 4],\n",
       " 'label': [1, 1, 1, 1, 1],\n",
       " 'sentence': [\"Our friends won't buy this analysis, let alone the next one we propose.\",\n",
       "  \"One more pseudo generalization and I'm giving up.\",\n",
       "  \"One more pseudo generalization or I'm giving up.\",\n",
       "  'The more we study verbs, the crazier they get.',\n",
       "  'Day by day the facts are getting murkier.']}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7276a21736814e29b7df2af0bdee2dab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [18, 20, 22, 23, 25],\n",
       " 'label': [0, 0, 0, 0, 0],\n",
       " 'sentence': ['They drank the pub.',\n",
       "  'The professor talked us.',\n",
       "  'We yelled ourselves.',\n",
       "  'We yelled Harry hoarse.',\n",
       "  'Harry coughed himself.']}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/bert_uncased_L-2_H-128_A-2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our friends won't buy this analysis, let alone the next one we propose.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(train_dataset[0]['sentence'])\n",
    "tokenizer(train_dataset[0]['sentence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[CLS] our friends won't buy this analysis, let alone the next one we propose. [SEP]\""
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(examples):\n",
    "    return tokenizer(\n",
    "            examples[\"sentence\"],\n",
    "            truncation=True,\n",
    "            padding=\"max_length\",\n",
    "            max_length=512,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5205de7df394d5a800f2ee94d3c9106",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train_dataset = train_dataset.map(encode, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0]]),\n",
       " 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [  101,  5965, 12808,  ...,     0,     0,     0],\n",
       "         [  101,  2198, 10948,  ...,     0,     0,     0],\n",
       "         [  101,  3021, 24471,  ...,     0,     0,     0]]),\n",
       " 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\n"
     ]
    }
   ],
   "source": [
    "for batch in dataloader:\n",
    "    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "day1",
   "language": "python",
   "name": "day1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: week_0_project_setup/inference.py
================================================
import torch
from model import ColaModel
from data import DataModule


class ColaPredictor:
    """Run single-sentence CoLA acceptability inference from a saved checkpoint.

    Args:
        model_path: path to a checkpoint loadable by
            ``ColaModel.load_from_checkpoint``.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        # Restore the trained weights, then switch to eval mode and freeze the
        # parameters, since this class only runs forward passes.
        self.model = ColaModel.load_from_checkpoint(model_path)
        self.model.eval()
        self.model.freeze()
        # Reuse the training DataModule so inference tokenizes the same way.
        self.processor = DataModule()
        # Applied to logits[0] (the row for the single input), hence dim=0.
        self.softmax = torch.nn.Softmax(dim=0)
        # Entry i names class i. The attribute keeps its original spelling
        # ("lables") because external code may refer to it.
        self.lables = ["unacceptable", "acceptable"]

    def predict(self, text):
        """Score ``text`` against every label.

        Args:
            text: the sentence to classify.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per entry of
            ``self.lables`` and in the same order; scores are softmax outputs.
        """
        encoded = self.processor.tokenize_data({"sentence": text})
        # Wrap each token list in an outer list to form a batch of size one.
        batch_ids = torch.tensor([encoded["input_ids"]])
        batch_mask = torch.tensor([encoded["attention_mask"]])
        logits = self.model(batch_ids, batch_mask)
        probabilities = self.softmax(logits[0]).tolist()
        return [
            {"label": name, "score": prob}
            for prob, name in zip(probabilities, self.lables)
        ]


if __name__ == "__main__":
    # Example usage: point this at a checkpoint written by train.py.
    predictor = ColaPredictor("./models/epoch=0-step=267.ckpt")
    sentence = "The boy is sitting on a bench"
    print(predictor.predict(sentence))


================================================
FILE: week_0_project_setup/model.py
================================================
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from transformers import AutoModel
from sklearn.metrics import accuracy_score


class ColaModel(pl.LightningModule):
    """Pretrained encoder with a linear head for binary CoLA classification.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        lr: learning rate for the Adam optimizer.
    """

    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=1e-2):
        super(ColaModel, self).__init__()
        # Stores model_name and lr in self.hparams; configure_optimizers reads
        # hparams["lr"].
        self.save_hyperparameters()

        self.num_classes = 2
        self.bert = AutoModel.from_pretrained(model_name)
        # Classification head over the first-token ([CLS]) representation;
        # sized from num_classes so the two can never disagree.
        self.W = nn.Linear(self.bert.config.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized logits of shape (batch, num_classes)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Take the hidden state at position 0 (the [CLS] token) for every example.
        h_cls = outputs.last_hidden_state[:, 0]
        logits = self.W(h_cls)
        return logits

    def training_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one training batch."""
        logits = self.forward(batch["input_ids"], batch["attention_mask"])
        loss = F.cross_entropy(logits, batch["label"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        logits = self.forward(batch["input_ids"], batch["attention_mask"])
        labels = batch["label"]
        loss = F.cross_entropy(logits, labels)
        _, preds = torch.max(logits, dim=1)
        # sklearn's contract is accuracy_score(y_true, y_pred): ground truth first.
        val_acc = torch.tensor(accuracy_score(labels.cpu(), preds.cpu()))
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate from hparams."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])


================================================
FILE: week_0_project_setup/requirements.txt
================================================
pytorch-lightning==1.2.10
datasets==1.6.2
transformers==4.5.1
scikit-learn==0.24.2

================================================
FILE: week_0_project_setup/train.py
================================================
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from data import DataModule
from model import ColaModel


def main():
    """Fit ColaModel on the CoLA DataModule.

    The checkpoint with the lowest val_loss is saved under ./models, and
    TensorBoard logs go under logs/ (run name "cola", fixed version 1).
    Training runs for at most 5 epochs and stops early if val_loss does not
    improve for 3 checks.
    """
    cola_data = DataModule()
    cola_model = ColaModel()

    callbacks = [
        # Keep the checkpoint that minimizes validation loss.
        ModelCheckpoint(dirpath="./models", monitor="val_loss", mode="min"),
        # Stop once val_loss stops improving (patience of 3).
        EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min"),
    ]
    gpu_count = 1 if torch.cuda.is_available() else 0
    tb_logger = pl.loggers.TensorBoardLogger("logs/", name="cola", version=1)

    trainer = pl.Trainer(
        default_root_dir="logs",
        gpus=gpu_count,
        max_epochs=5,
        fast_dev_run=False,
        logger=tb_logger,
        callbacks=callbacks,
    )
    trainer.fit(cola_model, cola_data)


if __name__ == "__main__":
    main()


================================================
FILE: week_1_wandb_logging/README.md
================================================

**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**

## Requirements:

This project uses Python 3.8

Create a virtual env with the following command:

```
conda create --name project-setup python=3.8
conda activate project-setup
```

Install the requirements:

```
pip install -r requirements.txt
```

## Running

### Training

After installing the requirements, in order to train the model simply run:

```
python train.py
```

### Monitoring

Once training is complete, at the end of the logs you will see something like:

```
wandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)
wandb: 
wandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc
```

Follow the link to see the wandb dashboard which contains all the plots.

### Inference

After training, update the model checkpoint path in the code and run

```
python inference.py
```

### Running notebooks

I am using [Jupyter lab](https://jupyter.org/install) to run the notebooks. 

Since I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.

To make sure Jupyter uses the virtualenv, run the following commands before running `jupyter lab`:

```
conda install ipykernel
python -m ipykernel install --user --name project-setup
pip install ipywidgets
```

================================================
FILE: week_1_wandb_logging/data.py
================================================
import torch
import datasets
import pytorch_lightning as pl

from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    """LightningDataModule serving tokenized GLUE CoLA train/validation splits.

    Args:
        model_name: Hugging Face id of the tokenizer to load.
        batch_size: batch size used by both dataloaders.
        max_length: sequences are truncated/padded to exactly this many tokens.
    """

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=64,
        max_length=128,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @staticmethod
    def _load_splits():
        """Return the (train, validation) splits of GLUE CoLA from load_dataset."""
        cola_dataset = load_dataset("glue", "cola")
        return cola_dataset["train"], cola_dataset["validation"]

    def prepare_data(self):
        # Downloads the dataset. The splits are still stored on self for callers
        # that rely on it, but setup() no longer depends on this method having run:
        # Lightning may invoke prepare_data on only one process.
        self.train_data, self.val_data = self._load_splits()

    def tokenize_data(self, example):
        """Tokenize ``example["sentence"]`` (a string or a list of strings).

        Output is truncated and padded to ``self.max_length`` tokens.
        """
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def setup(self, stage=None):
        # we set up only relevant datasets when stage is specified
        if stage == "fit" or stage is None:
            # Load the splits here if prepare_data did not run in this process
            # (otherwise this raised AttributeError).
            if (
                getattr(self, "train_data", None) is None
                or getattr(self, "val_data", None) is None
            ):
                self.train_data, self.val_data = self._load_splits()

            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=["input_ids", "attention_mask", "label"]
            )

            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            # output_all_columns keeps the untensorized columns (e.g. the raw
            # sentence) in validation batches alongside the tensors.
            self.val_data.set_format(
                type="torch",
                columns=["input_ids", "attention_mask", "label"],
                output_all_columns=True,
            )

    def train_dataloader(self):
        """Shuffled DataLoader over the tokenized training split."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Unshuffled DataLoader over the tokenized validation split."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )


if __name__ == "__main__":
    # Smoke test: build the module and print the input_ids shape of one training batch.
    data_model = DataModule()
    data_model.prepare_data()
    data_model.setup()
    first_batch = next(iter(data_model.train_dataloader()))
    print(first_batch["input_ids"].shape)


================================================
FILE: week_1_wandb_logging/experimental_notebooks/data_exploration.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    }
   ],
   "source": [
    "cola_dataset = load_dataset('glue', 'cola')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 8551\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1043\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1063\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cola_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8551, 1043, 1063)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset), len(val_dataset), len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': 'The sailors rode the breeze clear of the rocks.'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\n",
       " 'idx': Value(dtype='int32', id=None)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c681f26df104422a4c21a216b351949",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [0, 1, 2, 3, 4],\n",
       " 'label': [1, 1, 1, 1, 1],\n",
       " 'sentence': [\"Our friends won't buy this analysis, let alone the next one we propose.\",\n",
       "  \"One more pseudo generalization and I'm giving up.\",\n",
       "  \"One more pseudo generalization or I'm giving up.\",\n",
       "  'The more we study verbs, the crazier they get.',\n",
       "  'Day by day the facts are getting murkier.']}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7276a21736814e29b7df2af0bdee2dab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [18, 20, 22, 23, 25],\n",
       " 'label': [0, 0, 0, 0, 0],\n",
       " 'sentence': ['They drank the pub.',\n",
       "  'The professor talked us.',\n",
       "  'We yelled ourselves.',\n",
       "  'We yelled Harry hoarse.',\n",
       "  'Harry coughed himself.']}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/bert_uncased_L-2_H-128_A-2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our friends won't buy this analysis, let alone the next one we propose.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(train_dataset[0]['sentence'])\n",
    "tokenizer(train_dataset[0]['sentence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[CLS] our friends won't buy this analysis, let alone the next one we propose. [SEP]\""
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(examples):\n",
    "    return tokenizer(\n",
    "            examples[\"sentence\"],\n",
    "            truncation=True,\n",
    "            padding=\"max_length\",\n",
    "            max_length=512,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5205de7df394d5a800f2ee94d3c9106",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train_dataset = train_dataset.map(encode, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0]]),\n",
       " 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [  101,  5965, 12808,  ...,     0,     0,     0],\n",
       "         [  101,  2198, 10948,  ...,     0,     0,     0],\n",
       "         [  101,  3021, 24471,  ...,     0,     0,     0]]),\n",
       " 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\n"
     ]
    }
   ],
   "source": [
    "for batch in dataloader:\n",
    "    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "day1",
   "language": "python",
   "name": "day1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: week_1_wandb_logging/inference.py
================================================
import torch
from model import ColaModel
from data import DataModule


class ColaPredictor:
    """Score single sentences with a trained ``ColaModel`` checkpoint.

    Each prediction is a softmax probability for every label in
    ``["unacceptable", "acceptable"]``.
    """

    def __init__(self, model_path):
        """
        Args:
            model_path: path to the Lightning checkpoint to load.
        """
        self.model_path = model_path
        self.model = ColaModel.load_from_checkpoint(model_path)
        self.model.eval()
        self.model.freeze()
        # Used only for its `tokenize_data` method.
        self.processor = DataModule()
        # Applied to the logits of a single example (a 1-D tensor), hence dim=0.
        self.softmax = torch.nn.Softmax(dim=0)
        self.labels = ["unacceptable", "acceptable"]
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.lables = self.labels

    def predict(self, text):
        """Return per-label probabilities for one sentence.

        Args:
            text: the sentence to score (str).

        Returns:
            list of ``{"label": str, "score": float}`` dicts, one per label, in
            the order of ``self.labels``.
        """
        inference_sample = {"sentence": text}
        processed = self.processor.tokenize_data(inference_sample)
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([processed["input_ids"]]),
                torch.tensor([processed["attention_mask"]]),
            )
        # ColaModel.forward returns the Hugging Face output object rather than a
        # logits tensor. Indexing that object directly yields the full (1, 2)
        # logits, and softmax over dim=0 of that gives all-ones scores. So take
        # `.logits` explicitly (or accept a raw tensor), then the single row.
        logits = getattr(outputs, "logits", outputs)
        scores = self.softmax(logits[0]).tolist()
        return [
            {"label": label, "score": score}
            for score, label in zip(scores, self.labels)
        ]


if __name__ == "__main__":
    # Example usage: score one sentence with a trained checkpoint. Point the
    # path at your own checkpoint after training.
    predictor = ColaPredictor("./models/epoch=0-step=267.ckpt")
    print(predictor.predict("The boy is sitting on a bench"))


================================================
FILE: week_1_wandb_logging/model.py
================================================
import torch
import wandb
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification
import torchmetrics
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


class ColaModel(pl.LightningModule):
    """Binary CoLA acceptability classifier on a Hugging Face
    ``AutoModelForSequenceClassification`` backbone.

    Tracks accuracy, macro/micro precision and recall, and F1 with torchmetrics,
    and logs a W&B confusion matrix at the end of each validation epoch.
    """

    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
        """
        Args:
            model_name: Hugging Face model id of the pretrained backbone.
            lr: Adam learning rate, read back from ``self.hparams`` in
                ``configure_optimizers``.
        """
        super(ColaModel, self).__init__()
        # Stores model_name and lr in self.hparams.
        self.save_hyperparameters()

        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        self.num_classes = 2
        # Separate metric objects for train and validation so their
        # accumulated states are not shared.
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone and return its Hugging Face output object.

        The steps below read ``.logits`` and ``.loss`` from it; they always pass
        ``labels`` so that a loss is available.
        """
        outputs = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log train loss/accuracy."""
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        # The loss comes from the backbone output (labels were passed above).
        preds = torch.argmax(outputs.logits, 1)
        train_acc = self.train_accuracy_metric(preds, batch["label"])
        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", train_acc, prog_bar=True, on_epoch=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and metrics for one batch.

        Returns the labels and logits so ``validation_epoch_end`` can build a
        confusion matrix over the whole validation set.
        """
        labels = batch["label"]
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        """Log a W&B confusion matrix built from every validation batch.

        Args:
            outputs: list of the dicts returned by ``validation_step``.
        """
        if not outputs:
            # No validation batches were run; torch.cat would fail on [].
            return
        # Move to CPU before converting to numpy: .numpy() raises on CUDA tensors.
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu()
        logits = torch.cat([x["logits"] for x in outputs]).detach().cpu()
        # Only needed by the alternative plotting methods commented out below.
        preds = torch.argmax(logits, 1)

        ## There are multiple ways to track the metrics
        # 1. Confusion matrix plotting using inbuilt W&B method
        self.logger.experiment.log(
            {
                "conf": wandb.plot.confusion_matrix(
                    probs=logits.numpy(), y_true=labels.numpy()
                )
            }
        )

        # 2. Confusion Matrix plotting using scikit-learn method
        # wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})

        # 3. Confusion Matrix plotting using Seaborn
        # data = confusion_matrix(labels.numpy(), preds.numpy())
        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))
        # df_cm.index.name = "Actual"
        # df_cm.columns.name = "Predicted"
        # plt.figure(figsize=(7, 4))
        # plot = sns.heatmap(
        #     df_cm, cmap="Blues", annot=True, annot_kws={"size": 16}
        # )  # font size
        # self.logger.experiment.log({"Confusion Matrix": wandb.Image(plot)})

        # self.logger.experiment.log(
        #     {"roc": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}
        # )

    def configure_optimizers(self):
        """Adam over all parameters with the ``lr`` hyperparameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])


================================================
FILE: week_1_wandb_logging/requirements.txt
================================================
pytorch-lightning==1.2.10
datasets==1.6.2
transformers==4.5.1
scikit-learn==0.24.2
wandb
torchmetrics
matplotlib
seaborn

================================================
FILE: week_1_wandb_logging/train.py
================================================
import torch
import wandb
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from data import DataModule
from model import ColaModel


class SamplesVisualisationLogger(pl.Callback):
    """Callback that logs misclassified validation sentences to W&B.

    At the end of each validation run it scores the first batch of the
    datamodule's validation dataloader. It then logs the rows whose prediction
    differs from the label as a ``wandb.Table``.
    """

    def __init__(self, datamodule):
        """
        Args:
            datamodule: object whose ``val_dataloader()`` yields batches with
                "sentence", "input_ids", "attention_mask" and "label" keys.
        """
        super().__init__()

        self.datamodule = datamodule

    def on_validation_end(self, trainer, pl_module):
        """Score one validation batch and log the misclassified examples."""
        val_batch = next(iter(self.datamodule.val_dataloader()))
        sentences = val_batch["sentence"]

        # Batches come straight from a DataLoader (presumably on CPU). Move the
        # inputs to the model's device so this also works when training on GPU.
        device = pl_module.device
        with torch.no_grad():
            outputs = pl_module(
                val_batch["input_ids"].to(device),
                val_batch["attention_mask"].to(device),
            )
        # .numpy() only works on CPU tensors.
        preds = torch.argmax(outputs.logits, 1).cpu()
        labels = val_batch["label"].cpu()

        df = pd.DataFrame(
            {"Sentence": sentences, "Label": labels.numpy(), "Predicted": preds.numpy()}
        )

        wrong_df = df[df["Label"] != df["Predicted"]]
        trainer.logger.experiment.log(
            {
                "examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
                "global_step": trainer.global_step,
            }
        )


def main():
    """Train ``ColaModel`` on CoLA with W&B logging.

    Keeps the checkpoint with the lowest ``valid/loss`` under ``./models``. Stops
    early when ``valid/loss`` has not improved for 3 consecutive checks.
    """
    cola_data = DataModule()
    cola_model = ColaModel()

    # Lightning appends the ".ckpt" extension to `filename` itself. Including
    # it here would save the file as "best-checkpoint.ckpt.ckpt".
    checkpoint_callback = ModelCheckpoint(
        dirpath="./models",
        filename="best-checkpoint",
        monitor="valid/loss",
        mode="min",
    )

    early_stopping_callback = EarlyStopping(
        monitor="valid/loss", patience=3, verbose=True, mode="min"
    )

    wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
    trainer = pl.Trainer(
        max_epochs=1,
        logger=wandb_logger,
        callbacks=[
            checkpoint_callback,
            SamplesVisualisationLogger(cola_data),
            early_stopping_callback,
        ],
        log_every_n_steps=10,
        deterministic=True,
        # limit_train_batches=0.25,
        # limit_val_batches=0.25
    )
    trainer.fit(cola_model, cola_data)


if __name__ == "__main__":
    # Only train when executed as a script (`python train.py`), not on import.
    main()


================================================
FILE: week_2_hydra_config/README.md
================================================

**Note: The purpose of the project is to explore the libraries and learn how to use them, not to build a SOTA model.**

## Requirements:

This project uses Python 3.8

Create a virtual env with the following command:

```
conda create --name project-setup python=3.8
conda activate project-setup
```

Install the requirements:

```
pip install -r requirements.txt
```

## Running

### Training

After installing the requirements, in order to train the model simply run:

```
python train.py
```

### Monitoring

Once training is complete, you will see something like the following at the end of the logs:

```
wandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)
wandb: 
wandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc
```

Follow the link to see the wandb dashboard which contains all the plots.

### Inference

After training, update the model checkpoint path in `inference.py` and run:

```
python inference.py
```

### Running notebooks

I am using [Jupyter lab](https://jupyter.org/install) to run the notebooks. 

Since I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.

To make sure Jupyter uses the virtualenv, run the following commands before running `jupyter lab`:

```
conda install ipykernel
python -m ipykernel install --user --name project-setup
pip install ipywidgets
```

================================================
FILE: week_2_hydra_config/configs/config.yaml
================================================
defaults:
  - model: default
  - processing: default
  - training: default
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog

================================================
FILE: week_2_hydra_config/configs/model/default.yaml
================================================
name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier
tokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data

================================================
FILE: week_2_hydra_config/configs/processing/default.yaml
================================================
batch_size: 64
max_length: 128

================================================
FILE: week_2_hydra_config/configs/training/default.yaml
================================================
max_epochs: 1
log_every_n_steps: 10
deterministic: true
limit_train_batches: 0.25
limit_val_batches: ${training.limit_train_batches}

================================================
FILE: week_2_hydra_config/data.py
================================================
import torch
import datasets
import pytorch_lightning as pl

from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    """LightningDataModule for the GLUE CoLA dataset.

    Loads the ``glue``/``cola`` train and validation splits, tokenizes the
    ``sentence`` column with the tokenizer of ``model_name`` and serves
    PyTorch DataLoaders over ``input_ids``, ``attention_mask`` and ``label``.

    Args:
        model_name: Hugging Face model id whose tokenizer is loaded.
        batch_size: Batch size used by both dataloaders.
        max_length: Sentences are truncated / padded to exactly this many tokens.
    """

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=64,
        max_length=128,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _load_splits(self):
        """Load the raw CoLA train/validation splits onto ``self``."""
        cola_dataset = load_dataset("glue", "cola")
        self.train_data = cola_dataset["train"]
        self.val_data = cola_dataset["validation"]

    def prepare_data(self):
        """Download the CoLA dataset and keep the raw splits on ``self``.

        Lightning may call this on a single process only, so ``setup`` does
        not rely on the attributes assigned here (see ``setup``).
        """
        self._load_splits()

    def tokenize_data(self, example):
        """Tokenize ``example["sentence"]`` (a string or a batch of strings).

        Returns the tokenizer output, truncated and padded to ``max_length``.
        """
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def setup(self, stage=None):
        """Tokenize the splits and switch them to torch format for ``fit``.

        Args:
            stage: Lightning stage name; only ``"fit"`` or ``None`` prepares data.
        """
        # we set up only relevant datasets when stage is specified
        if stage == "fit" or stage is None:
            # State assigned in prepare_data() is not visible on processes where
            # Lightning skipped that hook (and is absent if the caller never ran
            # it), so load the splits here when they are missing.
            if not hasattr(self, "train_data") or not hasattr(self, "val_data"):
                self._load_splits()

            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=["input_ids", "attention_mask", "label"]
            )

            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            # Keep the non-tensor columns (e.g. the raw sentence) available in
            # validation batches alongside the torch-formatted columns.
            self.val_data.set_format(
                type="torch",
                columns=["input_ids", "attention_mask", "label"],
                output_all_columns=True,
            )

    def train_dataloader(self):
        """Shuffled DataLoader over the tokenized training split."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Unshuffled DataLoader over the tokenized validation split."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )


if __name__ == "__main__":
    # Smoke test: build the data module, load + tokenize the data and show the
    # shape of the input_ids tensor of the first training batch.
    dm = DataModule()
    dm.prepare_data()
    dm.setup()
    first_batch = next(iter(dm.train_dataloader()))
    print(first_batch["input_ids"].shape)


================================================
FILE: week_2_hydra_config/experimental_notebooks/data_exploration.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    }
   ],
   "source": [
    "cola_dataset = load_dataset('glue', 'cola')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 8551\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1043\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1063\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cola_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8551, 1043, 1063)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset), len(val_dataset), len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': 'The sailors rode the breeze clear of the rocks.'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\n",
       " 'idx': Value(dtype='int32', id=None)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c681f26df104422a4c21a216b351949",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [0, 1, 2, 3, 4],\n",
       " 'label': [1, 1, 1, 1, 1],\n",
       " 'sentence': [\"Our friends won't buy this analysis, let alone the next one we propose.\",\n",
       "  \"One more pseudo generalization and I'm giving up.\",\n",
       "  \"One more pseudo generalization or I'm giving up.\",\n",
       "  'The more we study verbs, the crazier they get.',\n",
       "  'Day by day the facts are getting murkier.']}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7276a21736814e29b7df2af0bdee2dab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [18, 20, 22, 23, 25],\n",
       " 'label': [0, 0, 0, 0, 0],\n",
       " 'sentence': ['They drank the pub.',\n",
       "  'The professor talked us.',\n",
       "  'We yelled ourselves.',\n",
       "  'We yelled Harry hoarse.',\n",
       "  'Harry coughed himself.']}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/bert_uncased_L-2_H-128_A-2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our friends won't buy this analysis, let alone the next one we propose.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(train_dataset[0]['sentence'])\n",
    "tokenizer(train_dataset[0]['sentence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[CLS] our friends won't buy this analysis, let alone the next one we propose. [SEP]\""
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(examples):\n",
    "    return tokenizer(\n",
    "            examples[\"sentence\"],\n",
    "            truncation=True,\n",
    "            padding=\"max_length\",\n",
    "            max_length=512,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5205de7df394d5a800f2ee94d3c9106",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train_dataset = train_dataset.map(encode, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0]]),\n",
       " 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [  101,  5965, 12808,  ...,     0,     0,     0],\n",
       "         [  101,  2198, 10948,  ...,     0,     0,     0],\n",
       "         [  101,  3021, 24471,  ...,     0,     0,     0]]),\n",
       " 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([7, 512]) torch.Size([7, 512]) torch.Size([7])\n"
     ]
    }
   ],
   "source": [
    "for batch in dataloader:\n",
    "    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "day1",
   "language": "python",
   "name": "day1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: week_2_hydra_config/inference.py
================================================
import torch
from model import ColaModel
from data import DataModule


class ColaPredictor:
    """Wraps a trained ``ColaModel`` checkpoint for single-sentence CoLA predictions.

    Args:
        model_path: Path to a PyTorch Lightning checkpoint of ``ColaModel``.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        self.model = ColaModel.load_from_checkpoint(model_path)
        # Inference only: evaluation mode and frozen parameters.
        self.model.eval()
        self.model.freeze()
        # Tokenization uses DataModule's default tokenizer, not one stored in the
        # checkpoint. NOTE(review): confirm it matches the tokenizer used for training.
        self.processor = DataModule()
        # Applied to one example's 1-D logits vector, so dim=0 is the class axis.
        self.softmax = torch.nn.Softmax(dim=0)
        # Index i is the name of class i. (Misspelled attribute name kept for callers.)
        self.lables = ["unacceptable", "acceptable"]

    def predict(self, text):
        """Score ``text`` against every label.

        Args:
            text: The sentence to classify.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per entry of
            ``self.lables`` and in the same order; scores are softmax probabilities.
        """
        inference_sample = {"sentence": text}
        processed = self.processor.tokenize_data(inference_sample)
        outputs = self.model(
            torch.tensor([processed["input_ids"]]),
            torch.tensor([processed["attention_mask"]]),
        )
        # ColaModel.forward returns the Hugging Face output object; fall back to the
        # object itself so a model that returns a raw logits tensor also works.
        logits = getattr(outputs, "logits", outputs)
        # logits has shape (1, num_classes): take the single row and normalise over the
        # class axis. (Softmax over dim=0 of the full (1, num_classes) tensor would
        # normalise over the batch axis and give 1.0 for every label.)
        scores = self.softmax(logits[0]).tolist()
        return [
            {"label": label, "score": score}
            for score, label in zip(scores, self.lables)
        ]


if __name__ == "__main__":
    # Smoke test: score a single sentence with a locally saved checkpoint.
    predictor = ColaPredictor("./models/best-checkpoint.ckpt")
    print(predictor.predict("The boy is sitting on a bench"))


================================================
FILE: week_2_hydra_config/model.py
================================================
import torch
import wandb
import hydra
import numpy as np
import pandas as pd
import torchmetrics
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification
from omegaconf import OmegaConf, DictConfig
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


class ColaModel(pl.LightningModule):
    """Binary CoLA acceptability classifier wrapping a Hugging Face sequence classifier.

    Args:
        model_name: Hugging Face model identifier loaded with
            ``AutoModelForSequenceClassification.from_pretrained`` (2 labels).
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
        super(ColaModel, self).__init__()
        # Records model_name and lr in self.hparams (lr is read in configure_optimizers).
        self.save_hyperparameters()

        self.num_classes = 2
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=self.num_classes
        )
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped Hugging Face model.

        Returns the Hugging Face output object. Callers read ``.logits``, and also
        ``.loss`` when ``labels`` is passed.
        """
        outputs = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute the loss on one training batch and log loss and accuracy."""
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        # The loss is computed by the Hugging Face model because labels are passed in,
        # so no separate F.cross_entropy call is needed.
        preds = torch.argmax(outputs.logits, 1)
        train_acc = self.train_accuracy_metric(preds, batch["label"])
        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", train_acc, prog_bar=True, on_epoch=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log loss plus classification metrics.

        Returns:
            ``{"labels": ..., "logits": ...}`` for ``validation_epoch_end``.
        """
        labels = batch["label"]
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        # NOTE(review): validation loss is logged with on_step=True; confirm the logged
        # key matches what any ModelCheckpoint/EarlyStopping callback monitors.
        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        """Log a confusion matrix over every validation batch of the epoch.

        Args:
            outputs: The list of dicts returned by ``validation_step``.
        """
        if not outputs:
            # No validation batches ran; torch.cat on an empty list would raise.
            return
        # .numpy() only works on CPU tensors, so gather everything on the CPU first
        # (the model may be training on a GPU).
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu()
        logits = torch.cat([x["logits"] for x in outputs]).detach().cpu()

        ## There are multiple ways to track the metrics
        # 1. Confusion matrix plotting using inbuilt W&B method.
        # Assumes the trainer's logger is a WandbLogger (experiment is a wandb run).
        # Raw logits are passed as `probs`: only their per-row argmax matters here.
        self.logger.experiment.log(
            {
                "conf": wandb.plot.confusion_matrix(
                    probs=logits.numpy(), y_true=labels.numpy()
                )
            }
        )

        # The alternatives below need the hard predictions:
        # preds = torch.argmax(logits, 1)

        # 2. Confusion Matrix plotting using scikit-learn method
        # wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.numpy(), preds)})

        # 3. Confusion Matrix plotting using Seaborn
        # data = confusion_matrix(labels.numpy(), preds.numpy())
        # df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))
        # df_cm.index.name = "Actual"
        # df_cm.columns.name = "Predicted"
        # plt.figure(figsize=(7, 4))
        # plot = sns.heatmap(
        #     df_cm, cmap="Blues", annot=True, annot_kws={"size": 16}
        # )  # font size
        # self.logger.experiment.log({"Confusion Matrix": wandb.Image(plot)})

        # self.logger.experiment.log(
        #     {"roc": wandb.plot.roc_curve(labels.numpy(), logits.numpy())}
        # )

    def configure_optimizers(self):
        """Adam over all parameters, using the ``lr`` hyperparameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])


================================================
FILE: week_2_hydra_config/requirements.txt
================================================
pytorch-lightning==1.2.10
datasets==1.6.2
transformers==4.5.1
scikit-learn==0.24.2
wandb
torchmetrics
matplotlib
seaborn
hydra-core
omegaconf
hydra_colorlog

================================================
FILE: week_2_hydra_config/train.py
================================================
import torch
import hydra
import wandb
import logging

import pandas as pd
import pytorch_lightning as pl
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from data import DataModule
from model import ColaModel

logger = logging.getLogger(__name__)


class SamplesVisualisationLogger(pl.Callback):
    """After each validation run, log misclassified validation examples to W&B.

    Args:
        datamodule: Data module whose ``val_dataloader()`` yields batches containing
            ``"sentence"``, ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
    """

    def __init__(self, datamodule):
        super().__init__()

        self.datamodule = datamodule

    def on_validation_end(self, trainer, pl_module):
        """Predict on the first validation batch and log the wrong predictions as a table."""
        val_batch = next(iter(self.datamodule.val_dataloader()))
        sentences = val_batch["sentence"]

        # The dataloader is iterated directly here (not by the trainer), so its tensors
        # are not moved to the model's device automatically. Inference only, so gradient
        # tracking is disabled.
        with torch.no_grad():
            outputs = pl_module(
                val_batch["input_ids"].to(pl_module.device),
                val_batch["attention_mask"].to(pl_module.device),
            )
        # .numpy() requires CPU tensors.
        preds = torch.argmax(outputs.logits, 1).cpu()
        labels = val_batch["label"].cpu()

        df = pd.DataFrame(
            {"Sentence": sentences, "Label": labels.numpy(), "Predicted": preds.numpy()}
        )

        wrong_df = df[df["Label"] != df["Predicted"]]
        # Assumes trainer.logger is a WandbLogger (experiment is a wandb run).
        trainer.logger.experiment.log(
            {
                "examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
                "global_step": trainer.global_step,
            }
        )


@hydra.main(config_path="./configs", config_name="config")
def main(cfg):
    """Train ``ColaModel`` on CoLA from the Hydra-composed configuration ``cfg``.

    Uses ``cfg.model`` (name, tokenizer), ``cfg.processing`` (batch_size, max_length)
    and ``cfg.training`` (trainer options). Logs to W&B and keeps the checkpoint
    with the lowest ``valid/loss`` under ``./models``.
    """
    logger.info(OmegaConf.to_yaml(cfg, resolve=True))
    model_cfg = cfg.model
    logger.info(f"Using the model: {model_cfg.name}")
    logger.info(f"Using the tokenizer: {model_cfg.tokenizer}")

    proc_cfg = cfg.processing
    datamodule = DataModule(
        model_cfg.tokenizer, proc_cfg.batch_size, proc_cfg.max_length
    )
    model = ColaModel(model_cfg.name)

    # Checkpointing and early stopping both track the same metric and direction.
    watched = dict(monitor="valid/loss", mode="min")
    best_ckpt = ModelCheckpoint(
        dirpath="./models", filename="best-checkpoint", **watched
    )
    stopper = EarlyStopping(patience=3, verbose=True, **watched)
    run_logger = WandbLogger(project="MLOps Basics", entity="raviraja")

    train_cfg = cfg.training
    trainer = pl.Trainer(
        logger=run_logger,
        callbacks=[best_ckpt, SamplesVisualisationLogger(datamodule), stopper],
        max_epochs=train_cfg.max_epochs,
        log_every_n_steps=train_cfg.log_every_n_steps,
        deterministic=train_cfg.deterministic,
        limit_train_batches=train_cfg.limit_train_batches,
        limit_val_batches=train_cfg.limit_val_batches,
    )
    trainer.fit(model, datamodule)
    wandb.finish()


if __name__ == "__main__":
    main()


================================================
FILE: week_3_dvc/README.md
================================================

**Note: The purpose of this project is to explore the libraries and learn how to use them, not to build a SOTA model.**

## Requirements:

This project uses Python 3.8

Create a virtual env with the following command:

```
conda create --name project-setup python=3.8
conda activate project-setup
```

Install the requirements:

```
pip install -r requirements.txt
```

## Running

### Training

After installing the requirements, in order to train the model simply run:

```
python train.py
```

### Monitoring

Once training is complete, you will see something like the following at the end of the logs:

```
wandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)
wandb:
wandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc
```

Follow the link to see the wandb dashboard which contains all the plots.

### Inference

After training, update the model checkpoint path in the code and run

```
python inference.py
```

### Versioning data

Refer to the blog: [DVC Configuration](https://www.ravirajag.dev/blog/mlops-dvc)

### Running notebooks

I am using [Jupyter lab](https://jupyter.org/install) to run the notebooks.

Since I am using a virtualenv, when I run the command `jupyter lab` it might or might not use the virtualenv.

To make sure the virtualenv is used, run the following commands before running `jupyter lab`:

```
conda install ipykernel
python -m ipykernel install --user --name project-setup
pip install ipywidgets
```

================================================
FILE: week_3_dvc/configs/config.yaml
================================================
defaults:
  - model: default
  - processing: default
  - training: default
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog

================================================
FILE: week_3_dvc/configs/model/default.yaml
================================================
name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier
tokenizer: google/bert_uncased_L-2_H-128_A-2        # tokenizer used for processing the data

================================================
FILE: week_3_dvc/configs/processing/default.yaml
================================================
batch_size: 64
max_length: 128

================================================
FILE: week_3_dvc/configs/training/default.yaml
================================================
max_epochs: 1
log_every_n_steps: 10
deterministic: true
limit_train_batches: 0.25
limit_val_batches: ${training.limit_train_batches}

================================================
FILE: week_3_dvc/data.py
================================================
import torch
import pytorch_lightning as pl

from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    """Lightning data module for the GLUE CoLA dataset.

    Args:
        model_name: Hugging Face identifier of the tokenizer to load.
        batch_size: Batch size for both the train and validation dataloaders.
        max_length: Sentences are truncated/padded to exactly this many tokens.
    """

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=64,
        max_length=128,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _load_splits(self):
        """Load GLUE/CoLA and store the raw train and validation splits on the instance."""
        cola_dataset = load_dataset("glue", "cola")
        self.train_data = cola_dataset["train"]
        self.val_data = cola_dataset["validation"]

    def prepare_data(self):
        """Load (downloading if needed) CoLA and keep the raw train/validation splits."""
        self._load_splits()

    def tokenize_data(self, example):
        """Tokenize ``example["sentence"]`` (a string or a list of strings).

        Sequences are truncated and padded to ``self.max_length`` tokens.
        """
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def setup(self, stage=None):
        """Tokenize the train/validation splits and format them as torch tensors.

        Lightning runs ``prepare_data`` on a single process only, so attributes it
        assigns may be missing here (e.g. in other distributed processes, or when
        ``setup`` is called directly); in that case the splits are loaded now.
        """
        # we set up only relevant datasets when stage is specified
        if stage == "fit" or stage is None:
            if not hasattr(self, "train_data") or not hasattr(self, "val_data"):
                self._load_splits()

            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=["input_ids", "attention_mask", "label"]
            )

            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            # output_all_columns keeps non-tensor columns such as "sentence" in the
            # validation batches alongside the torch-formatted ones.
            self.val_data.set_format(
                type="torch",
                columns=["input_ids", "attention_mask", "label"],
                output_all_columns=True,
            )

    def train_dataloader(self):
        """Shuffled DataLoader over the tokenized training split."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Unshuffled DataLoader over the tokenized validation split."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )


if __name__ == "__main__":
    # Smoke test: build the module and print the shape of one tokenized training batch.
    dm = DataModule()
    dm.prepare_data()
    dm.setup()
    first_batch = next(iter(dm.train_dataloader()))
    print(first_batch["input_ids"].shape)


================================================
FILE: week_3_dvc/dvcfiles/trained_model.dvc
================================================
wdir: ../models
outs:
- md5: c2f5c0a1954209865b9be1945f33ed6e
  size: 17567709
  path: best-checkpoint.ckpt


================================================
FILE: week_3_dvc/experimental_notebooks/data_exploration.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/Users/raviraja/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    }
   ],
   "source": [
    "cola_dataset = load_dataset('glue', 'cola')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 8551\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1043\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence', 'label', 'idx'],\n",
       "        num_rows: 1063\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cola_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8551, 1043, 1063)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset), len(val_dataset), len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence': 'The sailors rode the breeze clear of the rocks.'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0, 'label': -1, 'sentence': 'Bill whistled past the house.'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),\n",
       " 'idx': Value(dtype='int32', id=None)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c681f26df104422a4c21a216b351949",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [0, 1, 2, 3, 4],\n",
       " 'label': [1, 1, 1, 1, 1],\n",
       " 'sentence': [\"Our friends won't buy this analysis, let alone the next one we propose.\",\n",
       "  \"One more pseudo generalization and I'm giving up.\",\n",
       "  \"One more pseudo generalization or I'm giving up.\",\n",
       "  'The more we study verbs, the crazier they get.',\n",
       "  'Day by day the facts are getting murkier.']}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('acceptable'))[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7276a21736814e29b7df2af0bdee2dab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'idx': [18, 20, 22, 23, 25],\n",
       " 'label': [0, 0, 0, 0, 0],\n",
       " 'sentence': ['They drank the pub.',\n",
       "  'The professor talked us.',\n",
       "  'We yelled ourselves.',\n",
       "  'We yelled Harry hoarse.',\n",
       "  'Harry coughed himself.']}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.filter(lambda example: example['label'] == train_dataset.features['label'].str2int('unacceptable'))[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/bert_uncased_L-2_H-128_A-2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = cola_dataset['train']\n",
    "val_dataset = cola_dataset['validation']\n",
    "test_dataset = cola_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='google/bert_uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our friends won't buy this analysis, let alone the next one we propose.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(train_dataset[0]['sentence'])\n",
    "tokenizer(train_dataset[0]['sentence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[CLS] our friends won't buy this analysis, let alone the next one we propose. [SEP]\""
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenizer(train_dataset[0]['sentence'])['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(examples):\n",
    "    return tokenizer(\n",
    "            examples[\"sentence\"],\n",
    "            truncation=True,\n",
    "            padding=\"max_length\",\n",
    "            max_length=512,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5205de7df394d5a800f2ee94d3c9106",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train_dataset = train_dataset.map(encode, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0]]),\n",
       " 'input_ids': tensor([[  101,  2256,  2814,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         [  101,  2028,  2062,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [  101,  5965, 12808,  ...,     0,     0,     0],\n",
       "         [  101,  2198, 10948,  ...,     0,     0,     0],\n",
       "         [  101,  3021, 24471,  ...,     0,     0,     0]]),\n",
       " 'label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32])\n",
      "to
Download .txt
gitextract_g0pkv867/

├── .dvc/
│   ├── .gitignore
│   ├── config
│   └── plots/
│       ├── confusion.json
│       ├── confusion_normalized.json
│       ├── default.json
│       ├── linear.json
│       ├── scatter.json
│       └── smooth.json
├── .dvcignore
├── .github/
│   └── workflows/
│       ├── basic.yaml
│       └── build_docker_image.yaml
├── .gitignore
├── LICENSE
├── README.md
├── week_0_project_setup/
│   ├── README.md
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_1_wandb_logging/
│   ├── README.md
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_2_hydra_config/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── data.py
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_3_dvc/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── data.py
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── model.py
│   ├── requirements.txt
│   └── train.py
├── week_4_onnx/
│   ├── README.md
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── requirements.txt
│   ├── train.py
│   └── utils.py
├── week_5_docker/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_6_github_actions/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_7_ecr/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
├── week_8_serverless/
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── config.yaml
│   │   ├── model/
│   │   │   └── default.yaml
│   │   ├── processing/
│   │   │   └── default.yaml
│   │   └── training/
│   │       └── default.yaml
│   ├── convert_model_to_onnx.py
│   ├── data.py
│   ├── docker-compose.yml
│   ├── dvcfiles/
│   │   └── trained_model.dvc
│   ├── experimental_notebooks/
│   │   └── data_exploration.ipynb
│   ├── inference.py
│   ├── inference_onnx.py
│   ├── lambda_handler.py
│   ├── model.py
│   ├── parse_json.py
│   ├── requirements.txt
│   ├── requirements_inference.txt
│   ├── train.py
│   └── utils.py
└── week_9_monitoring/
    ├── Dockerfile
    ├── README.md
    ├── app.py
    ├── configs/
    │   ├── config.yaml
    │   ├── model/
    │   │   └── default.yaml
    │   ├── processing/
    │   │   └── default.yaml
    │   └── training/
    │       └── default.yaml
    ├── convert_model_to_onnx.py
    ├── data.py
    ├── docker-compose.yml
    ├── dvcfiles/
    │   └── trained_model.dvc
    ├── experimental_notebooks/
    │   └── data_exploration.ipynb
    ├── inference.py
    ├── inference_onnx.py
    ├── lambda_handler.py
    ├── model.py
    ├── parse_json.py
    ├── requirements.txt
    ├── requirements_inference.txt
    ├── train.py
    └── utils.py
Download .txt
SYMBOL INDEX (248 symbols across 65 files)

FILE: week_0_project_setup/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", bat...
    method prepare_data (line 16) | def prepare_data(self):
    method tokenize_data (line 21) | def tokenize_data(self, example):
    method setup (line 29) | def setup(self, stage=None):
    method train_dataloader (line 42) | def train_dataloader(self):
    method val_dataloader (line 47) | def val_dataloader(self):

FILE: week_0_project_setup/inference.py
  class ColaPredictor (line 6) | class ColaPredictor:
    method __init__ (line 7) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_0_project_setup/model.py
  class ColaModel (line 9) | class ColaModel(pl.LightningModule):
    method __init__ (line 10) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 18) | def forward(self, input_ids, attention_mask):
    method training_step (line 25) | def training_step(self, batch, batch_idx):
    method validation_step (line 31) | def validation_step(self, batch, batch_idx):
    method configure_optimizers (line 40) | def configure_optimizers(self):

FILE: week_0_project_setup/train.py
  function main (line 10) | def main():

FILE: week_1_wandb_logging/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", bat...
    method prepare_data (line 16) | def prepare_data(self):
    method tokenize_data (line 21) | def tokenize_data(self, example):
    method setup (line 29) | def setup(self, stage=None):
    method train_dataloader (line 44) | def train_dataloader(self):
    method val_dataloader (line 49) | def val_dataloader(self):

FILE: week_1_wandb_logging/inference.py
  class ColaPredictor (line 6) | class ColaPredictor:
    method __init__ (line 7) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_1_wandb_logging/model.py
  class ColaModel (line 13) | class ColaModel(pl.LightningModule):
    method __init__ (line 14) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 34) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 40) | def training_step(self, batch, batch_idx):
    method validation_step (line 51) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 76) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 109) | def configure_optimizers(self):

FILE: week_1_wandb_logging/train.py
  class SamplesVisualisationLogger (line 13) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 14) | def __init__(self, datamodule):
    method on_validation_end (line 19) | def on_validation_end(self, trainer, pl_module):
  function main (line 40) | def main():

FILE: week_2_hydra_config/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_2_hydra_config/inference.py
  class ColaPredictor (line 6) | class ColaPredictor:
    method __init__ (line 7) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_2_hydra_config/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_2_hydra_config/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_3_dvc/data.py
  class DataModule (line 8) | class DataModule(pl.LightningDataModule):
    method __init__ (line 9) | def __init__(
    method prepare_data (line 21) | def prepare_data(self):
    method tokenize_data (line 26) | def tokenize_data(self, example):
    method setup (line 34) | def setup(self, stage=None):
    method train_dataloader (line 49) | def train_dataloader(self):
    method val_dataloader (line 54) | def val_dataloader(self):

FILE: week_3_dvc/inference.py
  class ColaPredictor (line 6) | class ColaPredictor:
    method __init__ (line 7) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_3_dvc/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_3_dvc/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_4_onnx/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_4_onnx/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_4_onnx/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_4_onnx/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_4_onnx/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_4_onnx/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_4_onnx/utils.py
  function timing (line 5) | def timing(f):

FILE: week_5_docker/app.py
  function home_page (line 8) | async def home_page():
  function get_prediction (line 13) | async def get_prediction(text: str):

FILE: week_5_docker/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_5_docker/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_5_docker/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_5_docker/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_5_docker/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_5_docker/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_5_docker/utils.py
  function timing (line 5) | def timing(f):

FILE: week_6_github_actions/app.py
  function home_page (line 8) | async def home_page():
  function get_prediction (line 13) | async def get_prediction(text: str):

FILE: week_6_github_actions/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_6_github_actions/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_6_github_actions/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_6_github_actions/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_6_github_actions/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_6_github_actions/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_6_github_actions/utils.py
  function timing (line 5) | def timing(f):

FILE: week_7_ecr/app.py
  function home_page (line 8) | async def home_page():
  function get_prediction (line 13) | async def get_prediction(text: str):

FILE: week_7_ecr/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_7_ecr/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_7_ecr/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_7_ecr/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_7_ecr/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_7_ecr/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_7_ecr/utils.py
  function timing (line 5) | def timing(f):

FILE: week_8_serverless/app.py
  function home_page (line 8) | async def home_page():
  function get_prediction (line 13) | async def get_prediction(text: str):

FILE: week_8_serverless/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_8_serverless/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_8_serverless/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_8_serverless/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_8_serverless/lambda_handler.py
  function lambda_handler (line 10) | def lambda_handler(event, context):

FILE: week_8_serverless/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_8_serverless/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_8_serverless/utils.py
  function timing (line 5) | def timing(f):

FILE: week_9_monitoring/app.py
  function home_page (line 8) | async def home_page():
  function get_prediction (line 13) | async def get_prediction(text: str):

FILE: week_9_monitoring/convert_model_to_onnx.py
  function convert_model (line 14) | def convert_model(cfg):

FILE: week_9_monitoring/data.py
  class DataModule (line 9) | class DataModule(pl.LightningDataModule):
    method __init__ (line 10) | def __init__(
    method prepare_data (line 22) | def prepare_data(self):
    method tokenize_data (line 27) | def tokenize_data(self, example):
    method setup (line 35) | def setup(self, stage=None):
    method train_dataloader (line 50) | def train_dataloader(self):
    method val_dataloader (line 55) | def val_dataloader(self):

FILE: week_9_monitoring/inference.py
  class ColaPredictor (line 7) | class ColaPredictor:
    method __init__ (line 8) | def __init__(self, model_path):
    method predict (line 18) | def predict(self, text):

FILE: week_9_monitoring/inference_onnx.py
  class ColaONNXPredictor (line 9) | class ColaONNXPredictor:
    method __init__ (line 10) | def __init__(self, model_path):
    method predict (line 16) | def predict(self, text):

FILE: week_9_monitoring/lambda_handler.py
  function lambda_handler (line 17) | def lambda_handler(event, context):

FILE: week_9_monitoring/model.py
  class ColaModel (line 15) | class ColaModel(pl.LightningModule):
    method __init__ (line 16) | def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=...
    method forward (line 36) | def forward(self, input_ids, attention_mask, labels=None):
    method training_step (line 42) | def training_step(self, batch, batch_idx):
    method validation_step (line 53) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 78) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 111) | def configure_optimizers(self):

FILE: week_9_monitoring/train.py
  class SamplesVisualisationLogger (line 19) | class SamplesVisualisationLogger(pl.Callback):
    method __init__ (line 20) | def __init__(self, datamodule):
    method on_validation_end (line 25) | def on_validation_end(self, trainer, pl_module):
  function main (line 47) | def main(cfg):

FILE: week_9_monitoring/utils.py
  function timing (line 5) | def timing(f):
Condensed preview — 167 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (560K chars).
[
  {
    "path": ".dvc/.gitignore",
    "chars": 26,
    "preview": "/config.local\n/tmp\n/cache\n"
  },
  {
    "path": ".dvc/config",
    "chars": 173,
    "preview": "[core]\n    remote = model-store\n['remote \"storage\"']\n    url = gdrive://19JK5AFbqOBlrFVwDHjTrf9uvQFtS0954\n['remote \"mode"
  },
  {
    "path": ".dvc/plots/confusion.json",
    "chars": 3000,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvc/plots/confusion_normalized.json",
    "chars": 3145,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvc/plots/default.json",
    "chars": 714,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvc/plots/linear.json",
    "chars": 3654,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvc/plots/scatter.json",
    "chars": 3266,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvc/plots/smooth.json",
    "chars": 889,
    "preview": "{\n    \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\",\n    \"data\": {\n        \"values\": \"<DVC_METRIC_DATA>\"\n"
  },
  {
    "path": ".dvcignore",
    "chars": 139,
    "preview": "# Add patterns of files dvc should ignore, which could improve\n# the performance. Learn more at\n# https://dvc.org/doc/us"
  },
  {
    "path": ".github/workflows/basic.yaml",
    "chars": 866,
    "preview": "name: GitHub Actions Basic Flow\non: [push]\njobs:\n  Basic-workflow:\n    runs-on: ubuntu-latest\n    steps:\n      - name: B"
  },
  {
    "path": ".github/workflows/build_docker_image.yaml",
    "chars": 1333,
    "preview": "name: Create Docker Container\n\non: [push]\n\njobs:\n  mlops-container:\n    runs-on: ubuntu-latest\n    defaults:\n      run:\n"
  },
  {
    "path": ".gitignore",
    "chars": 1885,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2021 raviraja\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 12271,
    "preview": "# MLOps-Basics\n\n > There is nothing magic about magic. The magician merely understands something simple which doesn’t ap"
  },
  {
    "path": "week_0_project_setup/README.md",
    "chars": 1026,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_0_project_setup/data.py",
    "chars": 1850,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_0_project_setup/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_0_project_setup/inference.py",
    "chars": 1112,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model"
  },
  {
    "path": "week_0_project_setup/model.py",
    "chars": 1550,
    "preview": "import torch\nimport torch.nn as nn\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom transformers impo"
  },
  {
    "path": "week_0_project_setup/requirements.txt",
    "chars": 82,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2"
  },
  {
    "path": "week_0_project_setup/train.py",
    "chars": 916,
    "preview": "import torch\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightni"
  },
  {
    "path": "week_1_wandb_logging/README.md",
    "chars": 1387,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_1_wandb_logging/data.py",
    "chars": 1908,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_1_wandb_logging/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_1_wandb_logging/inference.py",
    "chars": 1112,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model"
  },
  {
    "path": "week_1_wandb_logging/model.py",
    "chars": 4670,
    "preview": "import torch\nimport wandb\nimport numpy as np\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom transformers import"
  },
  {
    "path": "week_1_wandb_logging/requirements.txt",
    "chars": 120,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_1_wandb_logging/train.py",
    "chars": 2011,
    "preview": "import torch\nimport wandb\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import Mod"
  },
  {
    "path": "week_2_hydra_config/README.md",
    "chars": 1387,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_2_hydra_config/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_2_hydra_config/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_2_hydra_config/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_2_hydra_config/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_2_hydra_config/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_2_hydra_config/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_2_hydra_config/inference.py",
    "chars": 1114,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model"
  },
  {
    "path": "week_2_hydra_config/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_2_hydra_config/requirements.txt",
    "chars": 156,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_2_hydra_config/train.py",
    "chars": 2591,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_3_dvc/README.md",
    "chars": 1488,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_3_dvc/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_3_dvc/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_3_dvc/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_3_dvc/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_3_dvc/data.py",
    "chars": 1996,
    "preview": "import torch\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer\n\n\n"
  },
  {
    "path": "week_3_dvc/dvcfiles/trained_model.dvc",
    "chars": 108,
    "preview": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_3_dvc/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_3_dvc/inference.py",
    "chars": 1114,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\n\n\nclass ColaPredictor:\n    def __init__(self, model"
  },
  {
    "path": "week_3_dvc/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_3_dvc/requirements.txt",
    "chars": 156,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_3_dvc/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_4_onnx/README.md",
    "chars": 1672,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_4_onnx/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_4_onnx/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_4_onnx/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_4_onnx/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_4_onnx/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_4_onnx/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_4_onnx/dvcfiles/trained_model.dvc",
    "chars": 108,
    "preview": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_4_onnx/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_4_onnx/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_4_onnx/inference_onnx.py",
    "chars": 1256,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_4_onnx/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_4_onnx/requirements.txt",
    "chars": 156,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_4_onnx/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_4_onnx/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  },
  {
    "path": "week_5_docker/Dockerfile",
    "chars": 230,
    "preview": "FROM huggingface/transformers-pytorch-cpu:latest\nCOPY ./ /app\nWORKDIR /app\nRUN pip install -r requirements_prod.txt\nENV "
  },
  {
    "path": "week_5_docker/README.md",
    "chars": 2068,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_5_docker/app.py",
    "chars": 364,
    "preview": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredic"
  },
  {
    "path": "week_5_docker/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_5_docker/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_5_docker/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_5_docker/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_5_docker/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_5_docker/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_5_docker/docker-compose.yml",
    "chars": 146,
    "preview": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:"
  },
  {
    "path": "week_5_docker/dvcfiles/trained_model.dvc",
    "chars": 108,
    "preview": "wdir: ../models\nouts:\n- md5: c2f5c0a1954209865b9be1945f33ed6e\n  size: 17567709\n  path: best-checkpoint.ckpt\n"
  },
  {
    "path": "week_5_docker/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_5_docker/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_5_docker/inference_onnx.py",
    "chars": 1290,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_5_docker/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_5_docker/requirements.txt",
    "chars": 173,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_5_docker/requirements_inference.txt",
    "chars": 127,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nu"
  },
  {
    "path": "week_5_docker/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_5_docker/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  },
  {
    "path": "week_6_github_actions/Dockerfile",
    "chars": 696,
    "preview": "FROM huggingface/transformers-pytorch-cpu:latest\n\nCOPY ./ /app\nWORKDIR /app\n\n# install requirements\nRUN pip install \"dvc"
  },
  {
    "path": "week_6_github_actions/README.md",
    "chars": 2538,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_6_github_actions/app.py",
    "chars": 364,
    "preview": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredic"
  },
  {
    "path": "week_6_github_actions/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_6_github_actions/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_6_github_actions/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_6_github_actions/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_6_github_actions/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_6_github_actions/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_6_github_actions/docker-compose.yml",
    "chars": 146,
    "preview": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:"
  },
  {
    "path": "week_6_github_actions/dvcfiles/trained_model.dvc",
    "chars": 98,
    "preview": "wdir: ../models\nouts:\n- md5: d82b8390fa2f09b121de4abfa094a7a9\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_6_github_actions/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_6_github_actions/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_6_github_actions/inference_onnx.py",
    "chars": 1290,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_6_github_actions/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_6_github_actions/parse_json.py",
    "chars": 211,
    "preview": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print("
  },
  {
    "path": "week_6_github_actions/requirements.txt",
    "chars": 173,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_6_github_actions/requirements_inference.txt",
    "chars": 130,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nu"
  },
  {
    "path": "week_6_github_actions/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_6_github_actions/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  },
  {
    "path": "week_7_ecr/Dockerfile",
    "chars": 716,
    "preview": "FROM huggingface/transformers-pytorch-cpu:latest\n\nCOPY ./ /app\nWORKDIR /app\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS"
  },
  {
    "path": "week_7_ecr/README.md",
    "chars": 3687,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_7_ecr/app.py",
    "chars": 364,
    "preview": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredic"
  },
  {
    "path": "week_7_ecr/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_7_ecr/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_7_ecr/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_7_ecr/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_7_ecr/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_7_ecr/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_7_ecr/docker-compose.yml",
    "chars": 146,
    "preview": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:"
  },
  {
    "path": "week_7_ecr/dvcfiles/trained_model.dvc",
    "chars": 98,
    "preview": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_7_ecr/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_7_ecr/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_7_ecr/inference_onnx.py",
    "chars": 1290,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_7_ecr/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_7_ecr/parse_json.py",
    "chars": 211,
    "preview": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print("
  },
  {
    "path": "week_7_ecr/requirements.txt",
    "chars": 173,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_7_ecr/requirements_inference.txt",
    "chars": 130,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nu"
  },
  {
    "path": "week_7_ecr/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_7_ecr/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  },
  {
    "path": "week_8_serverless/Dockerfile",
    "chars": 858,
    "preview": "FROM amazon/aws-lambda-python\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS_KEY\nARG MODEL_DIR=./models\nRUN mkdir $MODEL_D"
  },
  {
    "path": "week_8_serverless/README.md",
    "chars": 4029,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_8_serverless/app.py",
    "chars": 364,
    "preview": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredic"
  },
  {
    "path": "week_8_serverless/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_8_serverless/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_8_serverless/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_8_serverless/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_8_serverless/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_8_serverless/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_8_serverless/docker-compose.yml",
    "chars": 146,
    "preview": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:"
  },
  {
    "path": "week_8_serverless/dvcfiles/trained_model.dvc",
    "chars": 98,
    "preview": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_8_serverless/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_8_serverless/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_8_serverless/inference_onnx.py",
    "chars": 1290,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_8_serverless/lambda_handler.py",
    "chars": 726,
    "preview": "\"\"\"\nLambda wrapper\n\"\"\"\n\nimport json\nfrom inference_onnx import ColaONNXPredictor\n\ninferencing_instance = ColaONNXPredict"
  },
  {
    "path": "week_8_serverless/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_8_serverless/parse_json.py",
    "chars": 211,
    "preview": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print("
  },
  {
    "path": "week_8_serverless/requirements.txt",
    "chars": 173,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_8_serverless/requirements_inference.txt",
    "chars": 169,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nu"
  },
  {
    "path": "week_8_serverless/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_8_serverless/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  },
  {
    "path": "week_9_monitoring/Dockerfile",
    "chars": 858,
    "preview": "FROM amazon/aws-lambda-python\n\nARG AWS_ACCESS_KEY_ID\nARG AWS_SECRET_ACCESS_KEY\nARG MODEL_DIR=./models\nRUN mkdir $MODEL_D"
  },
  {
    "path": "week_9_monitoring/README.md",
    "chars": 4247,
    "preview": "\n**Note: The purpose of the project to explore the libraries and learn how to use them. Not to build a SOTA model.**\n\n##"
  },
  {
    "path": "week_9_monitoring/app.py",
    "chars": 364,
    "preview": "from fastapi import FastAPI\nfrom inference_onnx import ColaONNXPredictor\napp = FastAPI(title=\"MLOps Basics App\")\n\npredic"
  },
  {
    "path": "week_9_monitoring/configs/config.yaml",
    "chars": 158,
    "preview": "defaults:\n  - model: default\n  - processing: default\n  - training: default\n  - override hydra/job_logging: colorlog\n  - "
  },
  {
    "path": "week_9_monitoring/configs/model/default.yaml",
    "chars": 185,
    "preview": "name: google/bert_uncased_L-2_H-128_A-2             # model used for training the classifier\ntokenizer: google/bert_unca"
  },
  {
    "path": "week_9_monitoring/configs/processing/default.yaml",
    "chars": 30,
    "preview": "batch_size: 64\nmax_length: 128"
  },
  {
    "path": "week_9_monitoring/configs/training/default.yaml",
    "chars": 132,
    "preview": "max_epochs: 1\nlog_every_n_steps: 10\ndeterministic: true\nlimit_train_batches: 0.25\nlimit_val_batches: ${training.limit_tr"
  },
  {
    "path": "week_9_monitoring/convert_model_to_onnx.py",
    "chars": 1849,
    "preview": "import torch\nimport hydra\nimport logging\n\nfrom omegaconf.omegaconf import OmegaConf\n\nfrom model import ColaModel\nfrom da"
  },
  {
    "path": "week_9_monitoring/data.py",
    "chars": 2012,
    "preview": "import torch\nimport datasets\nimport pytorch_lightning as pl\n\nfrom datasets import load_dataset\nfrom transformers import "
  },
  {
    "path": "week_9_monitoring/docker-compose.yml",
    "chars": 146,
    "preview": "version: \"3\"\nservices:\n    prediction_api:\n        build: .\n        container_name: \"inference_container\"\n        ports:"
  },
  {
    "path": "week_9_monitoring/dvcfiles/trained_model.dvc",
    "chars": 98,
    "preview": "wdir: ../models\nouts:\n- md5: 02f3b0034769ba45d758ad1bb9de33a3\n  size: 17562590\n  path: model.onnx\n"
  },
  {
    "path": "week_9_monitoring/experimental_notebooks/data_exploration.ipynb",
    "chars": 32123,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports\"\n   ]\n  },\n  {\n   \"cell_"
  },
  {
    "path": "week_9_monitoring/inference.py",
    "chars": 1273,
    "preview": "import torch\nfrom model import ColaModel\nfrom data import DataModule\nfrom utils import timing\n\n\nclass ColaPredictor:\n   "
  },
  {
    "path": "week_9_monitoring/inference_onnx.py",
    "chars": 1388,
    "preview": "import numpy as np\nimport onnxruntime as ort\nfrom scipy.special import softmax\n\nfrom data import DataModule\nfrom utils i"
  },
  {
    "path": "week_9_monitoring/lambda_handler.py",
    "chars": 1027,
    "preview": "\"\"\"\nLambda wrapper\n\"\"\"\n\nimport json\nimport logging\nfrom inference_onnx import ColaONNXPredictor\n\nlogging.basicConfig()\nl"
  },
  {
    "path": "week_9_monitoring/model.py",
    "chars": 4727,
    "preview": "import torch\nimport wandb\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torchmetrics\nimport pytorch_lightni"
  },
  {
    "path": "week_9_monitoring/parse_json.py",
    "chars": 211,
    "preview": "import json\n\nwith open('creds.txt') as f:\n\tdata = f.read()\n\nprint(data)\n# data = json.loads(data, strict=False)\n# print("
  },
  {
    "path": "week_9_monitoring/requirements.txt",
    "chars": 173,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\ntransformers==4.5.1\nscikit-learn==0.24.2\nwandb\ntorchmetrics\nmatplotlib\nseaborn"
  },
  {
    "path": "week_9_monitoring/requirements_inference.txt",
    "chars": 169,
    "preview": "pytorch-lightning==1.2.10\ndatasets==1.6.2\nscikit-learn==0.24.2\nhydra-core\nomegaconf\nhydra_colorlog\nonnxruntime\nfastapi\nu"
  },
  {
    "path": "week_9_monitoring/train.py",
    "chars": 2651,
    "preview": "import torch\nimport hydra\nimport wandb\nimport logging\n\nimport pandas as pd\nimport pytorch_lightning as pl\nfrom omegaconf"
  },
  {
    "path": "week_9_monitoring/utils.py",
    "chars": 414,
    "preview": "import time\nfrom functools import wraps\n\n\ndef timing(f):\n    \"\"\"Decorator for timing functions\n    Usage:\n    @timing\n  "
  }
]

About this extraction

This page contains the full source code of the graviraja/MLOps-Basics GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 167 files (502.1 KB), approximately 162.0k tokens, and a symbol index with 248 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!