Repository: RyanWangZf/transtab
Branch: main
Commit: fdb34cf38abd
Files: 49
Total size: 510.6 KB

Directory structure:
gitextract_y454drjn/

├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── blog/
│   └── README.md
├── docs/
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   ├── source/
│   │   ├── about.rst
│   │   ├── conf.py
│   │   ├── data_preparation.rst
│   │   ├── example_encode.rst
│   │   ├── example_pretrain.rst
│   │   ├── example_transfer.rst
│   │   ├── fast_train.rst
│   │   ├── index.rst
│   │   ├── install.rst
│   │   ├── main_func.rst
│   │   ├── models.rst
│   │   ├── transtab.basemodel.rst
│   │   ├── transtab.build_classifier.rst
│   │   ├── transtab.build_contrastive_learner.rst
│   │   ├── transtab.build_encoder.rst
│   │   ├── transtab.build_extractor.rst
│   │   ├── transtab.classifier.rst
│   │   ├── transtab.contrastive.rst
│   │   ├── transtab.load_data.rst
│   │   ├── transtab.predict.rst
│   │   └── transtab.train.rst
│   └── sphinx-commands.txt
├── examples/
│   ├── contrastive_learning.ipynb
│   ├── fast_train.ipynb
│   ├── table_embedding.ipynb
│   ├── transfer_learning.ipynb
│   └── transfer_learning_regressor.ipynb
├── pypi_build_commands.txt
├── requirements.txt
├── setup.py
└── transtab/
    ├── __init__.py
    ├── constants.py
    ├── dataset.py
    ├── evaluator.py
    ├── modeling_transtab.py
    ├── tokenizer/
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── vocab.txt
    ├── trainer.py
    ├── trainer_utils.py
    └── transtab.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# checkpoint
/ckpt
/checkpoint

# MAC
.DS_Store


================================================
FILE: .readthedocs.yaml
================================================
# Read the Docs configuration file for Sphinx projects
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
  os: ubuntu-22.04
  tools:
    python: "3.9"
    # You can also specify other tool versions:
    # nodejs: "20"
    # rust: "1.70"
    # golang: "1.20"

# Build documentation in the "docs/" directory with Sphinx
sphinx:
  configuration: docs/source/conf.py
  # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs
  # builder: "dirhtml"
  # Fail on all warnings to avoid broken references
  # fail_on_warning: true

# Optionally build your docs in additional formats such as PDF and ePub
# formats:
#   - pdf
#   - epub

# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
  install:
    - requirements: docs/requirements.txt

================================================
FILE: LICENSE
================================================
BSD 2-Clause License

Copyright (c) 2022, Zifeng
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# TransTab: A flexible transferable tabular learning framework [[arxiv]](https://arxiv.org/pdf/2205.09328.pdf)


[![PyPI version](https://badge.fury.io/py/transtab.svg)](https://badge.fury.io/py/transtab)
[![Documentation Status](https://readthedocs.org/projects/transtab/badge/?version=latest)](https://transtab.readthedocs.io/en/latest/?badge=latest)
[![License](https://img.shields.io/badge/License-BSD_2--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause)
![GitHub Repo stars](https://img.shields.io/github/stars/ryanwangzf/transtab)
![GitHub Repo forks](https://img.shields.io/github/forks/ryanwangzf/transtab)
[![Downloads](https://pepy.tech/badge/transtab)](https://pepy.tech/project/transtab)
[![Downloads](https://pepy.tech/badge/transtab/month)](https://pepy.tech/project/transtab)


Document is available at https://transtab.readthedocs.io/en/latest/index.html.

Paper is available at https://arxiv.org/pdf/2205.09328.pdf.

5 min blog to understand TransTab at [realsunlab.medium.com](https://realsunlab.medium.com/transtab-learning-transferable-tabular-transformers-across-tables-1e34eec161b8)!

### News!
- [03/12/25] Version `0.0.7` with `TransTabRegressor` available for regression. Thanks @yuxinchenNU.

- [05/04/23] Check the version `0.0.5` of `TransTab`!

- [01/04/23] Check the version `0.0.3` of `TransTab`!

- [12/03/22] Check out our [[blog]](https://realsunlab.medium.com/transtab-learning-transferable-tabular-transformers-across-tables-1e34eec161b8) for a quick understanding of TransTab!

- [08/31/22] `0.0.2` supports encoding tabular inputs into embeddings directly. An example is provided [here](examples/table_embedding.ipynb). Several bugs were fixed.

## TODO

- [x] Table embedding.

- [x] Add regression support.

- [ ] Add support for directly processing tables with missing values.


### Features
This repository provides the Python package `transtab`, a flexible tabular prediction framework. Basic usage of `transtab` takes only a couple of lines!

```python
import transtab

# load dataset by specifying dataset name
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
     = transtab.load_data('credit-g')

# build classifier
model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

# build regressor
# model = transtab.build_regressor(cat_cols, num_cols, bin_cols)

# start training
transtab.train(model, trainset, valset, **training_arguments)

# make predictions, df_x is a pd.DataFrame with shape (n, d)
# return the predictions ypred with shape (n, 1) if binary classification;
# (n, n_class) if multiclass classification.
ypred = transtab.predict(model, df_x)
```

It's easy, isn't it?



## How to install

First, download the right ``pytorch`` version following the guide on https://pytorch.org/get-started/locally/.

As of Feb 2025, the PyPI version is no longer maintained; please install from GitHub instead:

```bash
pip install git+https://github.com/RyanWangZf/transtab.git
```



Please refer to the [installation guide](https://transtab.readthedocs.io/en/latest/install.html) for more guidance and troubleshooting.



## Transfer learning across tables

A novel feature of `transtab` is its ability to learn from multiple distinct tables. Triggering such training is easy:

```python
# load the pretrained transtab model
model = transtab.build_classifier(checkpoint='./ckpt')

# load a new tabular dataset
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
     = transtab.load_data('credit-approval')

# update categorical/numerical/binary column map of the loaded model
model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})

# then we just trigger the training on the new data
transtab.train(model, trainset, valset, **training_arguments)
```



## Contrastive pretraining on multiple tables

We can also conduct contrastive pretraining on multiple distinct tables:

```python
# load from multiple tabular datasets
dataname_list = ['credit-g', 'credit-approval']
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
     = transtab.load_data(dataname_list)

# build contrastive learner, set supervised=True for supervised VPCL
model, collate_fn = transtab.build_contrastive_learner(
    cat_cols, num_cols, bin_cols, supervised=True)

# start contrastive pretraining
transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)
```



## Citation

If you find this package useful, please consider citing the following paper:

```latex
@inproceedings{wang2022transtab,
  title={TransTab: Learning Transferable Tabular Transformers Across Tables},
  author={Wang, Zifeng and Sun, Jimeng},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
```


================================================
FILE: blog/README.md
================================================
# NeurIPS'22 | How to perform transfer learning and zero-shot learning on tabular data?

> This is our paper accepted by NeurIPS'22 with ratings 7/7/7, in which we study pretraining, transfer learning, and zero-shot learning for the tabular prediction task. The following are the links for this article and the code.

Paper: [TransTab: Learning Transferable Tabular Transformers Across Tables](https://arxiv.org/pdf/2205.09328.pdf)

Code: [TransTab-github](https://github.com/RyanWangZf/transtab)

Doc: [Transtab-doc](https://transtab.readthedocs.io/en/latest/)

---



## Tabular learning was not flexible

In this article, we use the term *tabular learning* to refer to predictive tasks on tabular data. For instance, many Kaggle competitions are based on tabular data, e.g., house price prediction, credit default detection, CTR prediction, etc. Basically, this type of task involves predicting a target label from a set of features, as in the following table

| index | feature A | feature B | feature C | label |
| ----- | --------- | --------- | --------- | ----- |
| 0     | $x_1$     | $x_2$     | $x_3$     | $y$   |

one might take a linear regression to solve this problem as
$$
y = ax_1 + b x_2 + c x_3 +d.
$$
Compared with images and text, tables are used even more frequently in industrial applications. Meanwhile, the *pretrain+finetune* paradigm has recently emerged in deep learning, flourishing especially in computer vision (CV) and natural language processing (NLP).



<figure>
<img src = "figure/fig1.png">
<figcaption align = "center"> 
<b>Figure 1:</b> CV or NLP models usually share the same basic input unit, i.e., pixel for images and word/token for texts. However, tabular models only accept a fixed-structure table: the train and test tables should *always* have equal column sets, which prevents us from transfer learning or zero-shot learning on tabular data.
</figcaption>
</figure>


In CV & NLP, pretrained models like BERT and ViT have become strong baselines for almost all tasks. By contrast, in the tabular learning domain, we usually encounter the case of "xgboost is all you need": GBDT models achieve competitive performance with less effort on data preprocessing and hyperparameter tuning than deep learning-based methods. In this circumstance, many researchers have started to think about how to outperform GBDT using deep learning, especially by leveraging its power on big multi-sourced data.



## Recent efforts on transfer learning for tabular learning

Of course, there have been some efforts on transfer learning for deep learning-based tabular models. For example, VIME [[1]](#1), SCARF [[2]](#2), and SubTab [[3]](#3) all employ self-supervision for tabular learning. Common self-supervision can be categorized into *generative* and *discriminative* learning: for the former, we mask several cells in the table and ask the model to recover the missing values; for the latter, we create positive samples by deleting or replacing cells.

Nonetheless, these methods hardly apply to real-world cases: all of them assume fixed-structure tables. In practice we rarely have one large unlabeled table; instead, we often have multiple heterogeneous labeled tables. The core challenge is how to leverage as much labeled data as possible while getting rid of heavy data preprocessing and missing-value imputation.

Because they only accept fixed-structure tables, all existing tabular methods are incapable of pretraining on multi-sourced tables. Once there is a minor change in a table's structure, e.g., a column named *age* is renamed to *ages*, the pretrained model becomes useless, and we must roll back to *data processing* $\to$ *feature engineering* $\to$ *model training*, which is costly in both time and money.

Therefore, we ask: is it possible to build a tabular model that encodes **arbitrary** input tables without any adaptation?



### Tabular learning is flexible

In fact, if we look back at tabular data, we shall see that column names are rich in semantics, a signal long neglected by previous methods.

| index | gender | age  | is_citizen |
| ----- | ------ | ---- | ---------- |
| 0     | male   | 25   | 0          |

In this example, we include three common types of features: *categorical*, *numerical*, and *binary*.

We argue that interpreting features in light of their column names is necessary. We know *25* under the column *age* means 25 years old rather than 25 km or 25 kg. We know *0* under the column *is_citizen* means the person is not a citizen, rather than the negation of anything else. Previous methods drop column names and force the model to learn semantics from the raw features alone, which is easy to implement but not transferable.

On the contrary, we ask why not just explicitly include the column names in the modeling. Surprisingly, we do not find any prior arts doing that in tabular learning.

Formally, we process three types of features through

- For categorical features: we concatenate the column name and the feature value, e.g., *gender is male*.
- For numerical features: we tokenize and embed the column name, then multiply the column embedding by the feature value.
- For binary features: we tokenize and embed the column name, then decide whether to pass this embedding to the encoder based on the feature value. If the value is 0, we drop the embedding.
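These three rules can be sketched in a few lines of Python. This is a toy illustration only: `embed` here is a random-vector stand-in for TransTab's real tokenizer and embedding layer, and `featurize` is a hypothetical helper, not the package API.

```python
import numpy as np

rng = np.random.default_rng(0)
EMB_DIM = 4
_vocab = {}

def embed(text):
    # toy word embedding: average of per-word random vectors
    vecs = []
    for w in text.lower().split():
        if w not in _vocab:
            _vocab[w] = rng.normal(size=EMB_DIM)
        vecs.append(_vocab[w])
    return np.mean(vecs, axis=0)

def featurize(row, cat_cols, num_cols, bin_cols):
    """Turn one table row into a list of token embeddings."""
    embs = []
    for c in cat_cols:                     # e.g. "gender male" -> text embedding
        embs.append(embed(f"{c} {row[c]}"))
    for c in num_cols:                     # column embedding scaled by the value
        embs.append(embed(c) * float(row[c]))
    for c in bin_cols:                     # keep the column embedding only if value is 1
        if int(row[c]) == 1:
            embs.append(embed(c))
    return embs

row = {"gender": "male", "age": 25, "is_citizen": 0}
tokens = featurize(row, ["gender"], ["age"], ["is_citizen"])
print(len(tokens))  # 2 -- the is_citizen embedding is dropped because its value is 0
```

The resulting list of embeddings is what a transformer encoder can consume, regardless of which columns the table happens to have.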



<figure>
<img src = "figure/fig2.png">
<figcaption align = "center"> 
<b>Figure 2:</b> The input feature processing module of *TransTab*.
</figcaption>
</figure>



With this processing module, we can linearize, tokenize, and embed any tabular data, which serves as the inputs for the encoder and the predictor.



## Pretraining for TransTab

Thanks to its flexibility, *TransTab* is capable of learning across multiple heterogeneous tables. However, it is nontrivial to design an appropriate pretraining algorithm for it.



<figure>
<img src = "figure/fig3.png">
<figcaption align = "center"> 
<b>Figure 3:</b> Learning across tables using naive supervised learning is harmful for representation learning.
</figcaption>
</figure>



The most straightforward way is illustrated above: we train a shared backbone encoder plus task-specific classifiers across tabular datasets. Nevertheless, we soon found this paradigm suboptimal. The flaw comes from the heterogeneity of the target labels: two datasets might have opposite definitions of their labels.

Accounting for this issue, we propose a novel **supervised contrastive learning** approach, namely **vertical partition contrastive learning (VPCL)** in this paper.



<figure>
<img src = "figure/fig4.png" width="80%">
<figcaption align = "center"> 
<b>Figure 4:</b> The proposed vertical partition contrastive learning (VPCL) approach for pretraining TransTab in our paper.
</figcaption>
</figure>

Its principle is:

- We split each row (sample) vertically into several partitions; each partition is a sample for contrastive learning.
- Partitions coming from rows with the same label are positive pairs, and vice versa.
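A toy sketch of this pair construction (the helper names are hypothetical, purely for illustration, and not the `transtab` API):

```python
def vertical_partitions(columns, num_partition):
    """Split a column list into num_partition roughly equal vertical slices."""
    cols = list(columns)
    size = (len(cols) + num_partition - 1) // num_partition  # ceiling division
    return [cols[i:i + size] for i in range(0, len(cols), size)]

def contrastive_pairs(labels, parts_per_row):
    """Enumerate (row_i, part_i, row_j, part_j, is_positive) over all partition pairs.

    Partitions from rows sharing a label are positives; all others are negatives.
    """
    pairs = []
    n = len(labels)
    for i in range(n):
        for j in range(n):
            for pi in range(parts_per_row):
                for pj in range(parts_per_row):
                    if i == j and pi == pj:
                        continue  # never pair a partition with itself
                    pairs.append((i, pi, j, pj, labels[i] == labels[j]))
    return pairs

parts = vertical_partitions(["age", "gender", "income", "is_citizen"], 2)
print(parts)  # [['age', 'gender'], ['income', 'is_citizen']]

pairs = contrastive_pairs([0, 1, 0], parts_per_row=2)
print(sum(p[-1] for p in pairs))  # 14 positive pairs out of 30
```

Even this tiny example yields 14 positive pairs from only three rows, showing how partitioning multiplies the contrastive signal compared with pairing whole rows.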

VPCL has the following merits:

- It significantly expands the number of pairs available for contrastive learning.
- It is more efficient and robust because it does not add task-specific classifiers.



## Which new tasks can TransTab solve?

Thanks to its flexibility, TransTab can now handle many new tasks.



<figure>
<img src = "figure/fig5.png">
<figcaption align = "center"> 
<b>Figure 5:</b> The new tasks that are amenable to TransTab.
</figcaption>
</figure>



- Learning across multiple labeled datasets (sharing the same label space) via supervised learning, then finetuning on each specific dataset.
- Learning from an incremental set of features and data, which usually originates from measurements added over time.
- Pretraining on multiple labeled/unlabeled datasets (which can have distinct labels) via supervised VPCL, then finetuning on each dataset.
- Learning from multiple labeled datasets (sharing the same label space) via supervised learning, then making predictions for brand-new data without any further parameter updates.



## Some experiment results

For the complete experiment results, please refer to [our paper](https://arxiv.org/pdf/2205.09328.pdf). Here we tell two interesting findings.



### Pretraining

<figure>
<img src = "figure/fig6.png">
<figcaption align = "center"> 
<b>Figure 6:</b> Experiment results of the pretraining+finetuning performances of TransTab.
</figcaption>
</figure>

The above figure illustrates the average performance (AUC) on multiple datasets. Left: clinical trial patient outcome prediction datasets. Right: many open datasets. The red dotted line indicates the naive supervised learning performance. The x-axis is the number of partitions used for VPCL.

We find:

- Supervised VPCL generally improves predictive performance.
- There is no universally optimal number of partitions for VPCL.
- Pretraining on the clinical trial datasets (left) brings much larger improvements than on the open datasets. This implies that it is still crucial to transfer knowledge from datasets in a similar domain, whereas the open datasets are highly heterogeneous.



### Zero-shot prediction



<figure>
<img src = "figure/fig7.png">
<figcaption align = "center"> 
<b>Figure 7:</b> Experiment results of the zero-shot learning performances of TransTab.
</figcaption>
</figure>

The above figures demonstrate the zero-shot prediction performances of TransTab. We split one dataset into two parts and vary the overlap ratio of their column sets: from 0% to 100%. We find:

- TransTab can make reasonable predictions even when there is **no column overlap** between the train and test data, which is really amazing.
- When the overlap ratio increases, we see better performance, which is reasonable.





## Use TransTab based on our package

We opensourced our package on [github](https://github.com/RyanWangZf/transtab) with the [documentations](https://transtab.readthedocs.io/en/latest/). It can be downloaded through

```shell
pip install git+https://github.com/RyanWangZf/transtab.git
```



And it is rather easy to use it in tabular prediction tasks on multiple distinct tables:

```python
import transtab

# load multiple datasets by passing a list of data names
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
    = transtab.load_data(['credit-g','credit-approval'])

# build transtab classifier model
model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

# specify training arguments, take validation loss for early stopping
training_arguments = {
    'num_epoch':5,
    'eval_metric':'val_loss',
    'eval_less_is_better':True,
    'output_dir':'./checkpoint'
    }

# start training
transtab.train(model, trainset, valset[0], **training_arguments)
```



For pretraining based on VPCL, we have

```python
import transtab

# load multiple datasets by passing a list of data names
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
    = transtab.load_data(['credit-g','credit-approval'])

# build contrastive learner, set supervised=True for supervised VPCL
model, collate_fn = transtab.build_contrastive_learner(
    cat_cols, num_cols, bin_cols,
    supervised=True, # if take supervised CL
    num_partition=4, # num of column partitions for pos/neg sampling
    overlap_ratio=0.5, # specify the overlap ratio of column partitions during the CL
)

# start contrastive pretraining
training_arguments = {
    'num_epoch':50,
    'batch_size':64,
    'lr':1e-4,
    'eval_metric':'val_loss',
    'eval_less_is_better':True,
    'output_dir':'./checkpoint' # save the pretrained model
    }

# pass the collate function to the train function
transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)
```

And after the pretraining completes, we can build a new classifier based on the pretrained model:

```python
# load the pretrained model and finetune on a target dataset
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
    = transtab.load_data('credit-approval')

# build transtab classifier model, and load from the pretrained dir
model = transtab.build_classifier(checkpoint='./checkpoint')

# update model's categorical/numerical/binary column dict
model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})
```

It is easy 😎 !



## Conclusion

Duplicating the success of deep learning in CV & NLP in the tabular domain still requires rethinking the basic elements. In CV, we have the pixel; in NLP, the token/word. In this paper, we propose a simple yet effective algorithm for modeling tabular data. Our method explores NLP techniques to enhance tabular learning, with the flexibility to handle arbitrary input tables. We hope it draws more attention to deep learning for tabular data.



## References
<a id="1"> [1] </a> Jinsung Yoon, Yao Zhang, James Jordon, and Mihaela van der Schaar. VIME: Extending the success of self-and semi-supervised learning to tabular domain. Advances in Neural Information Processing Systems, 33:11033–11043, 2020.

<a id="2"> [2] </a> Dara Bahri, Heinrich Jiang, Yi Tay, and Donald Metzler. SCARF: Self-supervised contrastive learning using random feature corruption. In International Conference on Learning Representations, 2022.

<a id="3"> [3] </a> Talip Ucar, Ehsan Hajiramezanali, and Lindsay Edwards. SubTab: Subsetting features of tabular data for self-supervised representation learning. Advances in Neural Information Processing Systems, 34, 2021.


================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd


================================================
FILE: docs/requirements.txt
================================================
sphinx-markdown-tables
recommonmark
sphinx==4.2.0
sphinx_rtd_theme==1.0.0
readthedocs-sphinx-search==0.1.1
loguru
numpy
scikit_learn
setuptools
transformers
tqdm
pandas>=1.3
openml>=0.10.0
torch


================================================
FILE: docs/source/about.rst
================================================
About Us
========

This package was developed and maintained by Zifeng Wang (Ph.D. Student @ UIUC).

Please refer to his `Homepage <https://zifengwang.xyz/>`_ for more details.

================================================
FILE: docs/source/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
import pdb

sys.path.insert(0, os.path.abspath('../../'))

# -- Project information -----------------------------------------------------

project = 'transtab'
copyright = '2022, Zifeng Wang'
author = 'Zifeng Wang'

# The full version, including alpha/beta/rc tags
release = 'alpha'

# Override the RTD default master doc
master_doc = 'index'

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'recommonmark',
    'sphinx_markdown_tables',
    'sphinx.ext.intersphinx',
    'sphinx.ext.imgmath',
    'sphinx.ext.viewcode',
    'sphinx.ext.napoleon',
    'sphinx.ext.autodoc',
]

napoleon_google_docstring = False
napoleon_numpy_docstring = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
# html_theme = 'alabaster'
html_theme = 'sphinx_rtd_theme'
# html_theme = 'furo'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']


================================================
FILE: docs/source/data_preparation.rst
================================================
Custom Dataset
==============

Here is the best practice to build your own datasets for `transtab`.

::

    project
    |
    ├── run_your_model.py
    |
    └─── data
         |
         ├── dataset1
         |   |    data_processed.csv
         |   |    binary_feature.txt
         |   └─── numerical_feature.txt
         |
         ├── dataset2
         |   
        ...

where the ``run_your_model.py`` is the code where you will load the dataset and train your models.

You should put the preprocessed table into ``data_processed.csv``, which should follow these protocols:

* All column names should be represented by meaningful natural language.
* All categorical features should be represented by meaningful natural language.
* All binary features should be represented by 0 or 1.
* All numerical features should be represented by continuous values.
* Store the processed table in ``data_processed.csv``.
* Store the binary column names in ``binary_feature.txt``. No need to create this file if there is no binary feature.
* Store the numerical column names in ``numerical_feature.txt``. No need to create this file if there is no numerical feature.
* All other columns will be treated as categorical or textual.
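
A minimal sketch of creating such a dataset directory with pandas (the column names and values below, including the ``target_label`` column, are made up purely for illustration):

.. code-block:: python

    import os
    import pandas as pd

    os.makedirs('data/dataset1', exist_ok=True)

    # toy table: one categorical, one numerical, one binary column plus a label
    df = pd.DataFrame({
        'occupation': ['teacher', 'engineer'],  # categorical: natural language
        'age': [34, 41],                        # numerical: continuous values
        'is_citizen': [1, 0],                   # binary: 0 or 1
        'target_label': [1, 0],                 # prediction target
    })
    df.to_csv('data/dataset1/data_processed.csv', index=False)

    # declare binary and numerical columns; remaining columns are categorical
    with open('data/dataset1/binary_feature.txt', 'w') as f:
        f.write('is_citizen\n')
    with open('data/dataset1/numerical_feature.txt', 'w') as f:
        f.write('age\n')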

After that, you can try to load the dataset by


.. code-block:: python

    transtab.load_data('./data/dataset1')


About ``dataset_config``, an example is provided as

.. code-block:: python

    EXAMPLE_DATACONFIG = {
        "example": { # dataset name
            "bin": ["bin1", "bin2"], # binary column names
            "cat": ["cat1", "cat2"], # categorical column names
            "num": ["num1", "num2"], # numerical column names
            "cols": ["bin1", "bin2", "cat1", "cat2", "num1", "num2"], # all column names
            "binary_indicator": ["1", "yes", "true", "positive", "t", "y"], # binary indicators in the binary columns, which will be converted to 1
            "data_split_idx": {
                "train":[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # row indices for training set
                "val":[10, 11, 12, 13, 14, 15, 16, 17, 18, 19], # row indices for validation set
                "test":[20, 21, 22, 23, 24, 25, 26, 27, 28, 29], # row indices for test set
                }
            }
        }
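
If you need to build ``data_split_idx`` yourself, the row indices can be generated with a small helper. This is a hypothetical sketch (``make_split_idx`` is not part of transtab, and the ratios are chosen for illustration):

```python
import random

def make_split_idx(n_rows, val_ratio=0.1, test_ratio=0.2, seed=42):
    """Shuffle row indices and cut them into disjoint train/val/test sets."""
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)
    n_test = int(n_rows * test_ratio)
    n_val = int(n_rows * val_ratio)
    return {
        "train": idx[n_test + n_val:],
        "val": idx[n_test:n_test + n_val],
        "test": idx[:n_test],
    }

# 30 rows -> 21 train / 3 val / 6 test, disjoint and covering all rows
split = make_split_idx(30)
```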



================================================
FILE: docs/source/example_encode.rst
================================================
Encode Tables
=============

*transtab* can take a ``pd.DataFrame`` as input and output the encoded sample-level embeddings.
The full code is available at `Notebook Example <https://github.com/ryanwangzf/transtab/blob/master/examples/table_embedding.ipynb>`_.


.. code-block:: python

    import transtab

    # load a dataset and start vanilla supervised training
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-g')
    
    # build transtab contrastive learner for pretraining
    model, collate_fn = transtab.build_contrastive_learner(cat_cols, num_cols, bin_cols)

    # start training
    training_arguments = {
        'num_epoch':50,
        'batch_size':64,
        'lr':1e-4,
        'eval_metric':'val_loss',
        'eval_less_is_better':True,
        'output_dir':'./checkpoint'
        }
    transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)

Now that the pretrained model has been saved in ``./checkpoint``, we can load it
from this path and use it to encode tables.


.. code-block:: python

    # load the pretrained model
    enc = transtab.build_encoder(
        binary_columns=bin_cols,
        checkpoint = './checkpoint'
    )

Then we can run the whole pretrained model and take the CLS token embedding from the last layer's outputs.

.. code-block:: python

    # encode tables to sample-level embeddings
    df = trainset[0]
    output = enc(df)
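
The returned ``output`` holds one embedding row per input sample. A common downstream use is similarity search between samples. Below is a minimal pure-Python sketch with placeholder vectors standing in for the real encoder outputs (the helper and the toy vectors are illustrative, not part of transtab):

```python
import math

def cosine(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

# placeholder 4-dim "embeddings" standing in for the encoder outputs
emb = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
]

# rank samples 1 and 2 by similarity to sample 0
sims = [cosine(emb[0], e) for e in emb[1:]]
```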


================================================
FILE: docs/source/example_pretrain.rst
================================================
Tabular Pretraining
===================

When encountering multiple distinct tables that may have different numbers of classes, performing
contrastive pretraining (called Vertical-Partition Contrastive Learning, VPCL, in the paper) is often
a better choice. This can be done using the transtab contrastive learner model.
The full code is available at `Notebook Example <https://github.com/ryanwangzf/transtab/blob/master/examples/contrastive_learning.ipynb>`_.


.. code-block:: python

    import transtab

    # load multiple datasets by passing a list of data names
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data(['credit-g','credit-approval'])

    # build contrastive learner, set supervised=True for supervised VPCL
    model, collate_fn = transtab.build_contrastive_learner(
        cat_cols, num_cols, bin_cols, 
        supervised=True, # if take supervised CL
        num_partition=4, # num of column partitions for pos/neg sampling
        overlap_ratio=0.5, # specify the overlap ratio of column partitions during the CL
    )
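
To build intuition for ``num_partition`` and ``overlap_ratio``: VPCL splits a table's columns into overlapping vertical partitions, and each partition of a sample serves as a view for positive/negative sampling. The sketch below illustrates such a partitioning; it is a hypothetical helper, not transtab's internal implementation:

```python
import math

def partition_columns(cols, num_partition, overlap_ratio):
    """Split a column list into overlapping vertical partitions."""
    base = math.ceil(len(cols) / num_partition)   # base columns per partition
    overlap = int(overlap_ratio * base)           # extra columns shared with the neighbor
    parts = []
    for i in range(num_partition):
        start = i * base
        end = min(start + base + overlap, len(cols))
        parts.append(cols[start:end])
    return parts

cols = ["col%d" % i for i in range(8)]
# each partition shares half of its base columns with its neighbor
parts = partition_columns(cols, num_partition=4, overlap_ratio=0.5)
```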


The function ``transtab.build_contrastive_learner`` returns both the CL model and the collate function
for the training dataloaders. We then train the model as follows:

.. code-block:: python

    # start contrastive pretraining training
    training_arguments = {
        'num_epoch':50,
        'batch_size':64,
        'lr':1e-4,
        'eval_metric':'val_loss',
        'eval_less_is_better':True,
        'output_dir':'./checkpoint' # save the pretrained model
        }

    # pass the collate function to the train function
    transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)

    
After pretraining completes, we can build a classifier from the checkpoint.

.. code-block:: python

    # load the pretrained model and finetune on a target dataset
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-approval')

    # build transtab classifier model, and load from the pretrained dir
    model = transtab.build_classifier(checkpoint='./checkpoint')

    # update model's categorical/numerical/binary column dict
    model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})



================================================
FILE: docs/source/example_transfer.rst
================================================
Tabular Transfer Learning
=========================

*transtab* is able to leverage the knowledge learned from broad data sources and then finetune on the target
data. This is easy to accomplish with this package.
The full code is available at `Notebook Example <https://github.com/ryanwangzf/transtab/blob/master/examples/transfer_learning.ipynb>`_.


.. code-block:: python

    import transtab

    # load a dataset and start vanilla supervised training
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-g')
    
    # build transtab classifier model
    model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

    # start training
    training_arguments = {
        'num_epoch':50,
        'eval_metric':'val_loss',
        'eval_less_is_better':True,
        'output_dir':'./checkpoint'
        }
    transtab.train(model, trainset, valset, **training_arguments)

Now that the pretrained model has been saved in ``./checkpoint``, we can load it
from this path and update it with new samples and columns.


.. code-block:: python

    # now let's load another data and try to leverage the pretrained model for finetuning
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-approval')

    # load the pretrained model
    model.load('./checkpoint')

    # update model's categorical/numerical/binary column dict
    model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})


Note that if the finetuning data differs from the pretraining data in the number of classes, this must
be specified explicitly in the update.

.. code-block:: python

    model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols, 'num_class':2})


Then we can continue to train the model just as in the supervised learning setting.

.. code-block:: python

    transtab.train(model, trainset, valset, **training_arguments)


================================================
FILE: docs/source/fast_train.rst
================================================
Fast Train with TransTab
=========================

*transtab* is featured for accepting variable-column tables for training and prediction. This is easy to do
with this package.
The full code is available at `Notebook Example <https://github.com/ryanwangzf/transtab/blob/master/examples/fast_train.ipynb>`_.


.. code-block:: python

    import transtab

    # load multiple datasets by passing a list of data names
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data(['credit-g','credit-approval'])

    # build transtab classifier model
    model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

    # specify training arguments, take validation loss for early stopping
    training_arguments = {
        'num_epoch':5, 
        'eval_metric':'val_loss',
        'eval_less_is_better':True,
        'output_dir':'./checkpoint'
        }


One can compute the validation loss on the validation data of the first dataset *credit-g* only:

.. code-block:: python

    transtab.train(model, trainset, valset[0], **training_arguments)

or take the macro-average loss over the validation sets of both datasets:

.. code-block:: python

    transtab.train(model, trainset, valset, **training_arguments)

After the training completes, we can load the best checkpoint judged by validation loss from the predefined *output_dir*
and make predictions.

.. code-block:: python

    model.load('./checkpoint')

    x_test, y_test = testset[0]

    ypred = transtab.predict(model, x_test)
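
For classification, the predictions are probabilities, which can then be scored, e.g. with ``transtab.evaluate(ypred, y_test, metric='auc')`` as in the example notebooks. For reference, the ranking interpretation of AUC can be sketched in a few lines of plain Python (an illustrative helper, not the package's evaluator):

```python
def auc_score(y_true, y_score):
    """AUC as the probability that a positive sample outranks a negative one."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos for n in neg
    )
    return wins / (len(pos) * len(neg))

print(auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```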


.. warning::

    Under this pure supervised learning setting, all the passed datasets should have the
    same **number of label classes**. For instance, *credit-g* and *credit-approval* are both
    binary classification tasks. This is because the classifier of `transtab` keeps only one
    classification head during training and prediction.
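
Before launching multi-dataset supervised training, it can be worth verifying that the label spaces are compatible. A small hypothetical sanity check (not a transtab utility):

```python
def same_num_classes(label_lists):
    """Return True if every dataset's labels span the same number of classes."""
    counts = {len(set(labels)) for labels in label_lists}
    return len(counts) == 1

# two binary-classification label sets -> compatible
ok = same_num_classes([[0, 1, 1, 0], [1, 0, 0]])
```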






================================================
FILE: docs/source/index.rst
================================================
Welcome to transtab documentation!
==================================

`transtab` is an easy-to-use **Python package** providing a flexible tabular prediction framework. **Tabular data** dominates machine learning applications in research & development, including healthcare, finance, advertising, engineering, etc.

`transtab` is featured for the following scenarios of tabular predictions:

* **Supervised learning**: vanilla training and prediction on tables with identical columns.
* **Transfer learning**: given multiple labeled tables that partially share columns, we enhance the model for each table by leveraging the other tables.
* **Incremental learning**: as a table incrementally grows with more columns, we update the existing model to handle the new table.
* **Table Pretraining**: we pretrain models on many tables with distinct columns and identifiers, then apply them to the target tabular prediction task.
* **Zero-shot inference**: we build a model for an unseen table that only partially overlaps with the training tables.

.. figure:: ../images/transtab_tasks.png

    The demonstration of ML modeling on different tabular data settings.
    Previous tabular methods only perform vanilla supervised training or pretraining on the same table because they only accept
    **fixed-column tables**. By contrast, TransTab covers the new tasks (1) to (4) as it accepts **variable-column** tables.


The basic usage of `transtab` can be done in a couple of lines:

.. code-block:: python

    import transtab

    # load dataset by specifying dataset name
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-g')

    # build classifier
    model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

    # start training; see transtab.train for the full list of training arguments
    transtab.train(model, trainset, valset, num_epoch=50)

    # make predictions, df_x is a pd.DataFrame with shape (n, d)
    # return the predictions ypred with shape (n, 1) if binary classification;
    # (n, n_class) if multiclass classification.
    ypred = transtab.predict(model, df_x)


It's easy, isn't it?

Let's start the journey from the `installation <https://transtab.readthedocs.io/en/latest/install.html>`_ and the `first demo on supervised tabular learning <https://transtab.readthedocs.io/en/latest/fast_train.html>`_ !

We also provide examples on `tabular transfer learning <https://transtab.readthedocs.io/en/latest/example_transfer.html>`_ and `tabular pretraining <https://transtab.readthedocs.io/en/latest/example_pretrain.html>`_ for a quick start.

----

**Citing transtab**:

If you use `transtab` in a scientific publication, we would appreciate citations to the following paper::

    @article{wang2022transtab,
        author = {Wang, Zifeng and Sun, Jimeng},
        title = {TransTab: Learning Transferable Tabular Transformers Across Tables},
        journal={arXiv preprint arXiv:2205.09328},
        year = {2022},
    }


.. toctree::
   :maxdepth: 2
   :hidden:
   :caption: Getting Started

   install
   fast_train
   example_transfer
   example_pretrain
   example_encode
   data_preparation

.. toctree::
    :maxdepth: 2
    :hidden:
    :caption: Documentation

    main_func
    models


.. toctree::
    :maxdepth: 2
    :hidden:
    :caption: Additional Information

    about


================================================
FILE: docs/source/install.rst
================================================
Installation
============

*transtab* was tested on Python 3.7+ and PyTorch 1.8.0+. Please follow the installation instructions below for the
torch version and CUDA device you are using:

`PyTorch Installation Instructions <https://pytorch.org/get-started/locally/>`_.

After that, *transtab* can be installed directly using **pip**. [As of Feb 2025, the PyPI version is no longer maintained; please install it from GitHub]:

.. code-block:: bash

    pip install git+https://github.com/RyanWangZf/transtab.git


Alternatively, you can clone the project and install it locally:

.. code-block:: bash

    git clone https://github.com/RyanWangZf/transtab.git
    cd transtab
    pip install .

**Troubleshooting**:

1. If you encounter ``ERROR: Failed building wheel for tokenizers`` on macOS/Linux, run

.. code-block:: bash

    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

then restart the terminal and run ``pip`` again.


================================================
FILE: docs/source/main_func.rst
================================================
Main Functions
==============

.. toctree::
    load_data<transtab.load_data>
    build_classifier<transtab.build_classifier>
    build_contrastive_learner<transtab.build_contrastive_learner>
    build_encoder<transtab.build_encoder>
    build_extractor<transtab.build_extractor>
    train<transtab.train>
    predict<transtab.predict>

================================================
FILE: docs/source/models.rst
================================================
Models
======

.. toctree::
    BaseModel<transtab.basemodel>
    TransTabClassifier<transtab.classifier>
    TransTabForCL<transtab.contrastive>

================================================
FILE: docs/source/transtab.basemodel.rst
================================================
TransTabModel
=============

.. automodule:: transtab.modeling_transtab
    :members: TransTabModel
    :no-undoc-members:
    :no-show-inheritance:

================================================
FILE: docs/source/transtab.build_classifier.rst
================================================
build_classifier
================

.. autofunction:: transtab.build_classifier

.. warning::
    If ``categorical_columns``,  ``numerical_columns``, and ``binary_columns`` are **ALL** not specified, the model takes **ALL** as ``categorical columns``,
    which may undermine the performance significantly.


================================================
FILE: docs/source/transtab.build_contrastive_learner.rst
================================================
build_contrastive_learner
=========================

.. autofunction:: transtab.build_contrastive_learner


================================================
FILE: docs/source/transtab.build_encoder.rst
================================================
build_encoder
=============

.. autofunction:: transtab.build_encoder

The returned encoder takes a ``pd.DataFrame`` as input and outputs the
encoded sample-level embeddings.

.. code-block:: python

    import pandas as pd
    import transtab

    # build the encoder
    enc = transtab.build_encoder(categorical_columns=['gender'], numerical_columns=['age'])

    # build a table for inputs
    df = pd.DataFrame({'age':[1,2], 'gender':['male','female']})

    # extract the outputs
    outputs = enc(df)

    print(outputs.shape)

    '''
    torch.Size([2, 128])
    '''

================================================
FILE: docs/source/transtab.build_extractor.rst
================================================
build_extractor
===============

.. autofunction:: transtab.build_extractor


The returned feature extractor takes a ``pd.DataFrame`` as input and outputs the
encoded results as a dict.

.. code-block:: python

    import pandas as pd
    import transtab

    # build the feature extractor
    extractor = transtab.build_extractor(categorical_columns=['gender'], numerical_columns=['age'])

    # build a table for inputs
    df = pd.DataFrame({'age':[1,2], 'gender':['male','female']})

    # extract the outputs
    outputs = extractor(df)

    print(outputs)

    '''
        {
        'x_num': tensor([[1.],[2.]], dtype=torch.float64),
        'num_col_input_ids': tensor([[2287]]),
        'x_cat_input_ids': tensor([[5907, 3287], [5907, 2931]]),
        'x_bin_input_ids': None,
        'num_att_mask': tensor([[1]]),
        'cat_att_mask': tensor([[1, 1], [1, 1]])
        }
    '''


================================================
FILE: docs/source/transtab.classifier.rst
================================================
TransTabClassifier
==================

.. autoclass:: transtab.modeling_transtab.TransTabClassifier
    :members:
    :no-undoc-members:
    :no-show-inheritance:


================================================
FILE: docs/source/transtab.contrastive.rst
================================================
TransTabForCL
=============

.. autoclass:: transtab.modeling_transtab.TransTabForCL
    :members:
    :no-undoc-members:
    :no-show-inheritance:


================================================
FILE: docs/source/transtab.load_data.rst
================================================
load_data
=========

.. autofunction:: transtab.load_data


*transtab* provides a flexible data loading function.
It can load arbitrary datasets from `openml <https://www.openml.org/>`_ through the `openml.datasets API <https://docs.openml.org/Python-API/>`_.

.. code-block:: python

    # specify the dataname
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-g')

    # or specify the dataset index (in openml)
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data(31)

It can also be used to load datasets from a local directory.

.. code-block:: python

    # specify the dataset dir
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('./data/credit-g')


Another important feature is that this function can load multiple datasets at once:

.. code-block:: python

    # specify a list of dataset dirs
    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data(['./data/credit-g','./data/credit-approval'])

One can also pass ``dataset_config`` to the ``load_data`` function to manipulate the input table directly.

.. code-block:: python

    # customize dataset configuration
    dataset_config = {
        'credit-g':{
            'columns':['a','b','c'], # specify the new columns for the table, should keep the same dimension as the original table.
            'cat':['a'], # specify all the categorical columns
            'bin':['b'], # specify all the binary columns
            'num':['c']} # specify all the numerical columns
            }

    allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \
        = transtab.load_data('credit-g', dataset_config=dataset_config)


However, this operation is not recommended. To avoid errors, it is better to store all these configurations locally, following
the guidance of `custom dataset <https://transtab.readthedocs.io/en/latest/data_preparation.html>`_.


================================================
FILE: docs/source/transtab.predict.rst
================================================
predict
=======

.. autofunction:: transtab.predict


================================================
FILE: docs/source/transtab.train.rst
================================================
train
=====

.. autofunction:: transtab.train


================================================
FILE: docs/sphinx-commands.txt
================================================
# build html files
sphinx-build -b html source build

================================================
FILE: examples/contrastive_learning.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0c0001bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "865b42a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    }
   ],
   "source": [
    "# load multiple datasets by passing a list of data names\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "    = transtab.load_data(['credit-g','credit-approval'])\n",
    "\n",
    "# build contrastive learner, set supervised=True for supervised VPCL\n",
    "model, collate_fn = transtab.build_contrastive_learner(\n",
    "    cat_cols, num_cols, bin_cols, \n",
    "    supervised=True, # if take supervised CL\n",
    "    num_partition=4, # num of column partitions for pos/neg sampling\n",
    "    overlap_ratio=0.5, # specify the overlap ratio of column partitions during the CL\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "78d0bc6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a6a12cd244e4672b360c68222c7b7f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 5.794664\n",
      "epoch: 0, train loss: 105.4182, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 1, test val_loss: 5.786065\n",
      "epoch: 1, train loss: 104.5511, lr: 0.000100, spent: 2.0 secs\n",
      "epoch: 2, test val_loss: 5.781867\n",
      "epoch: 2, train loss: 104.5076, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 3, test val_loss: 5.777907\n",
      "epoch: 3, train loss: 104.4728, lr: 0.000100, spent: 4.1 secs\n",
      "epoch: 4, test val_loss: 5.775703\n",
      "epoch: 4, train loss: 104.4284, lr: 0.000100, spent: 5.0 secs\n",
      "epoch: 5, test val_loss: 5.772933\n",
      "epoch: 5, train loss: 104.4126, lr: 0.000100, spent: 6.0 secs\n",
      "epoch: 6, test val_loss: 5.771537\n",
      "epoch: 6, train loss: 104.3681, lr: 0.000100, spent: 6.9 secs\n",
      "epoch: 7, test val_loss: 5.768374\n",
      "epoch: 7, train loss: 104.3112, lr: 0.000100, spent: 7.8 secs\n",
      "epoch: 8, test val_loss: 5.766492\n",
      "epoch: 8, train loss: 104.3186, lr: 0.000100, spent: 8.8 secs\n",
      "epoch: 9, test val_loss: 5.763317\n",
      "epoch: 9, train loss: 104.2437, lr: 0.000100, spent: 9.7 secs\n",
      "epoch: 10, test val_loss: 5.763273\n",
      "epoch: 10, train loss: 104.2665, lr: 0.000100, spent: 10.8 secs\n",
      "epoch: 11, test val_loss: 5.758865\n",
      "epoch: 11, train loss: 104.2031, lr: 0.000100, spent: 12.0 secs\n",
      "epoch: 12, test val_loss: 5.761363\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 104.2412, lr: 0.000100, spent: 13.1 secs\n",
      "epoch: 13, test val_loss: 5.760094\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 13, train loss: 104.2192, lr: 0.000100, spent: 14.4 secs\n",
      "epoch: 14, test val_loss: 5.756854\n",
      "epoch: 14, train loss: 104.1880, lr: 0.000100, spent: 15.7 secs\n",
      "epoch: 15, test val_loss: 5.755385\n",
      "epoch: 15, train loss: 104.1087, lr: 0.000100, spent: 17.0 secs\n",
      "epoch: 16, test val_loss: 5.755942\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 16, train loss: 104.1531, lr: 0.000100, spent: 18.3 secs\n",
      "epoch: 17, test val_loss: 5.758205\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 17, train loss: 104.2000, lr: 0.000100, spent: 19.4 secs\n",
      "epoch: 18, test val_loss: 5.748805\n",
      "epoch: 18, train loss: 104.0332, lr: 0.000100, spent: 20.5 secs\n",
      "epoch: 19, test val_loss: 5.748421\n",
      "epoch: 19, train loss: 104.0516, lr: 0.000100, spent: 21.8 secs\n",
      "epoch: 20, test val_loss: 5.749574\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 20, train loss: 104.0346, lr: 0.000100, spent: 22.9 secs\n",
      "epoch: 21, test val_loss: 5.749054\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 21, train loss: 104.0557, lr: 0.000100, spent: 23.9 secs\n",
      "epoch: 22, test val_loss: 5.752270\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 22, train loss: 104.0468, lr: 0.000100, spent: 25.1 secs\n",
      "epoch: 23, test val_loss: 5.749521\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 23, train loss: 104.0925, lr: 0.000100, spent: 26.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:45.227 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:56:45.242 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:56:45.379 | INFO     | transtab.trainer:train:137 - training complete, cost 27.2 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 24, test val_loss: 5.751015\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start contrastive pretraining training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'batch_size':64,\n",
    "    'lr':1e-4,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }\n",
    "\n",
    "transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "85e9ad3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:48.450 | WARNING  | transtab.modeling_transtab:_check_column_overlap:254 - No cat/num/bin cols specified, will take ALL columns as categorical! Ignore this warning if you specify the `checkpoint` to load the model.\n",
      "2022-08-31 10:56:48.527 | INFO     | transtab.modeling_transtab:load:782 - missing keys: ['clf.fc.weight', 'clf.fc.bias', 'clf.norm.weight', 'clf.norm.bias']\n",
      "2022-08-31 10:56:48.528 | INFO     | transtab.modeling_transtab:load:783 - unexpected keys: ['projection_head.dense.weight']\n",
      "2022-08-31 10:56:48.528 | INFO     | transtab.modeling_transtab:load:784 - load model from ./checkpoint\n",
      "2022-08-31 10:56:48.542 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n",
      "2022-08-31 10:56:48.556 | WARNING  | transtab.modeling_transtab:_check_column_overlap:254 - No cat/num/bin cols specified, will take ALL columns as categorical! Ignore this warning if you specify the `checkpoint` to load the model.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8bc6cedea8c74fa0a79a6201160b8641",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.683971\n",
      "epoch: 0, train loss: 5.4453, lr: 0.000100, spent: 0.3 secs\n",
      "epoch: 1, test val_loss: 0.646593\n",
      "epoch: 1, train loss: 5.2291, lr: 0.000100, spent: 0.6 secs\n",
      "epoch: 2, test val_loss: 0.598986\n",
      "epoch: 2, train loss: 4.9122, lr: 0.000100, spent: 0.8 secs\n",
      "epoch: 3, test val_loss: 0.571086\n",
      "epoch: 3, train loss: 4.6084, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 4, test val_loss: 0.500248\n",
      "epoch: 4, train loss: 4.2688, lr: 0.000100, spent: 1.3 secs\n",
      "epoch: 5, test val_loss: 0.461829\n",
      "epoch: 5, train loss: 3.8759, lr: 0.000100, spent: 1.6 secs\n",
      "epoch: 6, test val_loss: 0.418263\n",
      "epoch: 6, train loss: 3.5448, lr: 0.000100, spent: 1.9 secs\n",
      "epoch: 7, test val_loss: 0.406784\n",
      "epoch: 7, train loss: 3.3226, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 8, test val_loss: 0.415289\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 8, train loss: 3.2534, lr: 0.000100, spent: 2.5 secs\n",
      "epoch: 9, test val_loss: 0.395700\n",
      "epoch: 9, train loss: 3.1036, lr: 0.000100, spent: 2.7 secs\n",
      "epoch: 10, test val_loss: 0.477691\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 10, train loss: 2.9625, lr: 0.000100, spent: 3.2 secs\n",
      "epoch: 11, test val_loss: 0.394624\n",
      "epoch: 11, train loss: 2.9855, lr: 0.000100, spent: 3.5 secs\n",
      "epoch: 12, test val_loss: 0.395159\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 3.0646, lr: 0.000100, spent: 3.7 secs\n",
      "epoch: 13, test val_loss: 0.520994\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 13, train loss: 3.0765, lr: 0.000100, spent: 4.0 secs\n",
      "epoch: 14, test val_loss: 0.388927\n",
      "epoch: 14, train loss: 3.0590, lr: 0.000100, spent: 4.3 secs\n",
      "epoch: 15, test val_loss: 0.447461\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 2.8070, lr: 0.000100, spent: 4.5 secs\n",
      "epoch: 16, test val_loss: 0.402370\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 16, train loss: 2.6713, lr: 0.000100, spent: 4.7 secs\n",
      "epoch: 17, test val_loss: 0.393792\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 17, train loss: 2.7131, lr: 0.000100, spent: 5.0 secs\n",
      "epoch: 18, test val_loss: 0.455256\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 18, train loss: 2.7538, lr: 0.000100, spent: 5.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:53.974 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:56:54.000 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:56:54.130 | INFO     | transtab.trainer:train:137 - training complete, cost 5.6 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 19, test val_loss: 0.406734\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# load the pretrained model and finetune on a target dataset\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "     = transtab.load_data('credit-approval')\n",
    "\n",
    "# build transtab classifier model, and load from the pretrained dir\n",
    "model = transtab.build_classifier(checkpoint='./checkpoint')\n",
    "\n",
    "# update model's categorical/numerical/binary column dict\n",
    "model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})\n",
    "\n",
    "# start finetuning\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }\n",
    "transtab.train(model, trainset, valset, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ba5e5238",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.8382(0.06)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.8382272091644043]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluation\n",
    "x_test, y_test = testset\n",
    "ypred = transtab.predict(model, x_test)\n",
    "transtab.evaluate(ypred, y_test, metric='auc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5d6d70",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

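Note on the training log above: the `EarlyStopping counter: N out of 5` lines show patience-based early stopping — the counter increments each epoch the validation loss fails to improve on the best seen so far, resets on improvement, and halts training once it reaches the patience. The sketch below is a minimal, hypothetical reimplementation of that counter logic for illustration only; it is NOT transtab's actual `EarlyStopping` class (see `transtab/trainer_utils.py` for the real one).

```python
# Minimal sketch of patience-based early stopping, mirroring the
# "EarlyStopping counter: N out of 5" lines in the log above.
# Illustrative only -- not transtab's actual EarlyStopping implementation.

class EarlyStopping:
    def __init__(self, patience=5, less_is_better=True):
        self.patience = patience
        self.less_is_better = less_is_better
        self.best = None          # best score observed so far
        self.counter = 0          # epochs without improvement
        self.should_stop = False

    def step(self, metric):
        # Normalize so that lower is always better internally.
        score = metric if self.less_is_better else -metric
        if self.best is None or score < self.best:
            self.best = score     # improvement: record it and reset counter
            self.counter = 0
        else:
            self.counter += 1     # no improvement this epoch
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
```

With `patience=5` and `eval_less_is_better=True` (as in the training arguments above), training stops after five consecutive epochs without a new best validation loss — which is exactly why the run above ends at "counter: 5 out of 5 / early stopped" and then reloads the best checkpoint.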

================================================
FILE: examples/fast_train.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0bc8ef17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e06b2eb3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    }
   ],
   "source": [
    "# load multiple datasets by passing a list of data names\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "    = transtab.load_data(['credit-g','credit-approval'])\n",
    "\n",
    "# build transtab classifier model\n",
    "model = transtab.build_classifier(cat_cols, num_cols, bin_cols)\n",
    "\n",
    "# specify training arguments; use validation loss for early stopping\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'batch_size':128,\n",
    "    'lr':1e-4,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f0c84e5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>636</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.061957</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>500&lt;=X&lt;1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.076868</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>736</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.622318</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>922</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.061406</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>511</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.244085</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.232143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>no known property</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>845</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.205018</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>492</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.029412</td>\n",
       "      <td>0.054308</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>100&lt;=X&lt;500</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>849</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.025256</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.678571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>stores</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>297</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.057060</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.464286</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.114834</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.303571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>700 rows × 20 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "636              0               1  0.294118       0.061957   \n",
       "182              0               1  0.250000       0.076868   \n",
       "736              0               1  0.294118       0.622318   \n",
       "922              0               1  0.073529       0.061406   \n",
       "511              1               1  0.470588       0.244085   \n",
       "..             ...             ...       ...            ...   \n",
       "845              1               1  0.250000       0.205018   \n",
       "492              0               1  0.029412       0.054308   \n",
       "849              0               1  0.117647       0.025256   \n",
       "297              0               0  0.088235       0.057060   \n",
       "98               0               1  0.470588       0.114834   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "636                1.000000         0.000000  0.160714          0.000000   \n",
       "182                1.000000         0.333333  0.375000          0.333333   \n",
       "736                0.000000         1.000000  0.071429          0.333333   \n",
       "922                0.666667         1.000000  0.053571          0.000000   \n",
       "511                0.333333         0.333333  0.232143          0.000000   \n",
       "..                      ...              ...       ...               ...   \n",
       "845                0.333333         0.666667  0.285714          0.000000   \n",
       "492                0.000000         0.000000  0.142857          0.333333   \n",
       "849                1.000000         1.000000  0.678571          0.000000   \n",
       "297                1.000000         0.333333  0.464286          0.000000   \n",
       "98                 1.000000         1.000000  0.303571          0.000000   \n",
       "\n",
       "     num_dependents checking_status                  credit_history  \\\n",
       "636             0.0     no checking                   existing paid   \n",
       "182             1.0              <0                        all paid   \n",
       "736             0.0        0<=X<200                   existing paid   \n",
       "922             0.0              <0                   existing paid   \n",
       "511             0.0     no checking                   existing paid   \n",
       "..              ...             ...                             ...   \n",
       "845             0.0        0<=X<200                   existing paid   \n",
       "492             0.0     no checking  critical/other existing credit   \n",
       "849             0.0              <0                   existing paid   \n",
       "297             0.0     no checking                   existing paid   \n",
       "98              0.0        0<=X<200  critical/other existing credit   \n",
       "\n",
       "                 purpose    savings_status employment     personal_status  \\\n",
       "636             radio/tv       500<=X<1000     4<=X<7  female div/dep/mar   \n",
       "182              new car  no known savings     1<=X<4         male single   \n",
       "736             used car              <100     1<=X<4  female div/dep/mar   \n",
       "922             radio/tv              <100         <1  female div/dep/mar   \n",
       "511             used car              <100     1<=X<4         male single   \n",
       "..                   ...               ...        ...                 ...   \n",
       "845  furniture/equipment  no known savings     4<=X<7         male single   \n",
       "492             radio/tv        100<=X<500     1<=X<4  female div/dep/mar   \n",
       "849             radio/tv              <100        >=7         male single   \n",
       "297              new car  no known savings        >=7         male single   \n",
       "98              radio/tv              <100        >=7         male single   \n",
       "\n",
       "    other_parties property_magnitude other_payment_plans   housing  \\\n",
       "636          none                car                none       own   \n",
       "182          none     life insurance                none       own   \n",
       "736          none                car                none      rent   \n",
       "922          none     life insurance                none      rent   \n",
       "511          none  no known property                none  for free   \n",
       "..            ...                ...                 ...       ...   \n",
       "845          none                car                none       own   \n",
       "492          none     life insurance                none       own   \n",
       "849          none        real estate              stores       own   \n",
       "297  co applicant     life insurance                none       own   \n",
       "98           none        real estate                none       own   \n",
       "\n",
       "                           job  \n",
       "636                    skilled  \n",
       "182         unskilled resident  \n",
       "736  high qualif/self emp/mgmt  \n",
       "922                    skilled  \n",
       "511  high qualif/self emp/mgmt  \n",
       "..                         ...  \n",
       "845                    skilled  \n",
       "492                    skilled  \n",
       "849         unskilled resident  \n",
       "297         unskilled resident  \n",
       "98                     skilled  \n",
       "\n",
       "[700 rows x 20 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainset[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "058f667e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.205882</td>\n",
       "      <td>0.309013</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.196429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>100&lt;=X&lt;500</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>924</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.364367</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.642857</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>male div/sep</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>bank</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>931</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.078134</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>796</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.205882</td>\n",
       "      <td>0.399527</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.571429</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>226</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.647059</td>\n",
       "      <td>0.589358</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&gt;=1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>no known property</td>\n",
       "      <td>bank</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>380</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.235294</td>\n",
       "      <td>0.107956</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.357143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>768</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.185265</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.063937</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.178571</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>business</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>guarantor</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>527</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.068945</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.410714</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.103555</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 20 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "32               1               1  0.205882       0.309013   \n",
       "924              1               1  0.294118       0.364367   \n",
       "931              1               1  0.073529       0.078134   \n",
       "796              1               1  0.205882       0.399527   \n",
       "226              1               1  0.647059       0.589358   \n",
       "..             ...             ...       ...            ...   \n",
       "380              1               1  0.235294       0.107956   \n",
       "768              1               1  0.117647       0.185265   \n",
       "85               1               1  0.117647       0.063937   \n",
       "527              0               1  0.000000       0.068945   \n",
       "117              0               0  0.088235       0.103555   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "32                 0.333333         0.333333  0.196429          0.333333   \n",
       "924                0.333333         0.000000  0.642857          0.000000   \n",
       "931                1.000000         0.333333  0.053571          0.000000   \n",
       "796                0.000000         1.000000  0.571429          0.000000   \n",
       "226                0.000000         0.333333  0.142857          0.333333   \n",
       "..                      ...              ...       ...               ...   \n",
       "380                1.000000         1.000000  0.357143          0.000000   \n",
       "768                0.000000         1.000000  0.160714          0.666667   \n",
       "85                 1.000000         0.333333  0.178571          0.333333   \n",
       "527                0.333333         0.000000  0.410714          0.333333   \n",
       "117                0.333333         0.666667  0.142857          0.333333   \n",
       "\n",
       "     num_dependents checking_status                  credit_history  \\\n",
       "32              0.0        0<=X<200                   existing paid   \n",
       "924             0.0              <0                        all paid   \n",
       "931             0.0        0<=X<200                   existing paid   \n",
       "796             1.0              <0                   existing paid   \n",
       "226             0.0        0<=X<200                   existing paid   \n",
       "..              ...             ...                             ...   \n",
       "380             0.0              <0                   existing paid   \n",
       "768             0.0        0<=X<200  critical/other existing credit   \n",
       "85              0.0     no checking  critical/other existing credit   \n",
       "527             1.0     no checking  critical/other existing credit   \n",
       "117             0.0              <0  critical/other existing credit   \n",
       "\n",
       "                 purpose    savings_status employment     personal_status  \\\n",
       "32               new car        100<=X<500     1<=X<4         male single   \n",
       "924  furniture/equipment              <100         <1        male div/sep   \n",
       "931             radio/tv              <100         <1  female div/dep/mar   \n",
       "796             used car  no known savings        >=7         male single   \n",
       "226             radio/tv            >=1000     4<=X<7         male single   \n",
       "..                   ...               ...        ...                 ...   \n",
       "380  furniture/equipment  no known savings     4<=X<7         male single   \n",
       "768  furniture/equipment              <100        >=7         male single   \n",
       "85              business              <100     1<=X<4  female div/dep/mar   \n",
       "527             radio/tv              <100     4<=X<7         male single   \n",
       "117  furniture/equipment  no known savings         <1  female div/dep/mar   \n",
       "\n",
       "    other_parties property_magnitude other_payment_plans   housing  \\\n",
       "32           none                car                none       own   \n",
       "924          none     life insurance                bank       own   \n",
       "931          none                car                none       own   \n",
       "796          none     life insurance                none  for free   \n",
       "226  co applicant  no known property                bank       own   \n",
       "..            ...                ...                 ...       ...   \n",
       "380          none                car                none       own   \n",
       "768          none                car                none      rent   \n",
       "85      guarantor        real estate                none       own   \n",
       "527          none        real estate                none       own   \n",
       "117  co applicant        real estate                none      rent   \n",
       "\n",
       "                           job  \n",
       "32                     skilled  \n",
       "924                    skilled  \n",
       "931                    skilled  \n",
       "796                    skilled  \n",
       "226                    skilled  \n",
       "..                         ...  \n",
       "380                    skilled  \n",
       "768                    skilled  \n",
       "85   high qualif/self emp/mgmt  \n",
       "527         unskilled resident  \n",
       "117                    skilled  \n",
       "\n",
       "[100 rows x 20 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valset[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af2eed94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3018579e308d4eb995ed65b3581b7f06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.624792\n",
      "epoch: 0, train loss: 6.6127, lr: 0.000100, spent: 0.5 secs\n",
      "epoch: 1, test val_loss: 0.599838\n",
      "epoch: 1, train loss: 6.3586, lr: 0.000100, spent: 1.2 secs\n",
      "epoch: 2, test val_loss: 0.593658\n",
      "epoch: 2, train loss: 6.0999, lr: 0.000100, spent: 1.8 secs\n",
      "epoch: 3, test val_loss: 0.550265\n",
      "epoch: 3, train loss: 5.8295, lr: 0.000100, spent: 2.3 secs\n",
      "epoch: 4, test val_loss: 0.527351\n",
      "epoch: 4, train loss: 5.6347, lr: 0.000100, spent: 2.8 secs\n",
      "epoch: 5, test val_loss: 0.508950\n",
      "epoch: 5, train loss: 5.5123, lr: 0.000100, spent: 3.3 secs\n",
      "epoch: 6, test val_loss: 0.485854\n",
      "epoch: 6, train loss: 5.4929, lr: 0.000100, spent: 3.9 secs\n",
      "epoch: 7, test val_loss: 0.522198\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 7, train loss: 5.6552, lr: 0.000100, spent: 4.3 secs\n",
      "epoch: 8, test val_loss: 0.478467\n",
      "epoch: 8, train loss: 5.7420, lr: 0.000100, spent: 4.7 secs\n",
      "epoch: 9, test val_loss: 0.515104\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 9, train loss: 5.3993, lr: 0.000100, spent: 5.3 secs\n",
      "epoch: 10, test val_loss: 0.474058\n",
      "epoch: 10, train loss: 5.3141, lr: 0.000100, spent: 5.8 secs\n",
      "epoch: 11, test val_loss: 0.473926\n",
      "epoch: 11, train loss: 5.2754, lr: 0.000100, spent: 6.3 secs\n",
      "epoch: 12, test val_loss: 0.470752\n",
      "epoch: 12, train loss: 5.1095, lr: 0.000100, spent: 6.8 secs\n",
      "epoch: 13, test val_loss: 0.478428\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 13, train loss: 5.0845, lr: 0.000100, spent: 7.3 secs\n",
      "epoch: 14, test val_loss: 0.454532\n",
      "epoch: 14, train loss: 5.1003, lr: 0.000100, spent: 8.0 secs\n",
      "epoch: 15, test val_loss: 0.462518\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 5.0139, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 16, test val_loss: 0.453442\n",
      "epoch: 16, train loss: 4.9912, lr: 0.000100, spent: 9.1 secs\n",
      "epoch: 17, test val_loss: 0.459327\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 17, train loss: 4.9310, lr: 0.000100, spent: 9.5 secs\n",
      "epoch: 18, test val_loss: 0.442287\n",
      "epoch: 18, train loss: 4.8740, lr: 0.000100, spent: 10.2 secs\n",
      "epoch: 19, test val_loss: 0.466330\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 19, train loss: 4.8456, lr: 0.000100, spent: 10.8 secs\n",
      "epoch: 20, test val_loss: 0.436802\n",
      "epoch: 20, train loss: 4.7808, lr: 0.000100, spent: 11.2 secs\n",
      "epoch: 21, test val_loss: 0.472410\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 21, train loss: 4.7860, lr: 0.000100, spent: 11.6 secs\n",
      "epoch: 22, test val_loss: 0.448208\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 22, train loss: 4.9795, lr: 0.000100, spent: 12.2 secs\n",
      "epoch: 23, test val_loss: 0.426601\n",
      "epoch: 23, train loss: 4.8747, lr: 0.000100, spent: 12.8 secs\n",
      "epoch: 24, test val_loss: 0.556543\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 24, train loss: 4.9586, lr: 0.000100, spent: 13.2 secs\n",
      "epoch: 25, test val_loss: 0.455203\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 25, train loss: 4.9627, lr: 0.000100, spent: 13.8 secs\n",
      "epoch: 26, test val_loss: 0.581238\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 26, train loss: 5.0275, lr: 0.000100, spent: 14.2 secs\n",
      "epoch: 27, test val_loss: 0.501105\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 27, train loss: 5.2915, lr: 0.000100, spent: 14.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:57:39.041 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:57:39.057 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:57:39.187 | INFO     | transtab.trainer:train:137 - training complete, cost 15.3 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 28, test val_loss: 0.461543\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start training, take the validation loss on average for evaluation\n",
    "transtab.train(model, trainset, valset, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b65a489",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make predictions on the first dataset 'credit-g'\n",
    "x_test, y_test = testset[0]\n",
    "ypred = transtab.predict(model, x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6eefaa05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.7399(0.06)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.7399011920073604]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate the predictions with bootstrapping estimate\n",
    "transtab.evaluate(ypred, y_test, seed=123, metric='auc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d19852",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: examples/table_embedding.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9aa34ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ce7052e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n"
     ]
    }
   ],
   "source": [
    "# load a dataset and start vanilla supervised training\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "    = transtab.load_data('credit-g')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4e709521",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0740c9e1a09844238618d786a971d916",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 6.349929\n",
      "epoch: 0, train loss: 72.9975, lr: 0.000100, spent: 1.3 secs\n",
      "epoch: 1, test val_loss: 6.043663\n",
      "epoch: 1, train loss: 62.8806, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 2, test val_loss: 5.999826\n",
      "epoch: 2, train loss: 61.3078, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 3, test val_loss: 5.989734\n",
      "epoch: 3, train loss: 61.0470, lr: 0.000100, spent: 3.9 secs\n",
      "epoch: 4, test val_loss: 5.986117\n",
      "epoch: 4, train loss: 60.9742, lr: 0.000100, spent: 4.8 secs\n",
      "epoch: 5, test val_loss: 5.984314\n",
      "epoch: 5, train loss: 60.9454, lr: 0.000100, spent: 5.8 secs\n",
      "epoch: 6, test val_loss: 5.983197\n",
      "epoch: 6, train loss: 60.9270, lr: 0.000100, spent: 6.7 secs\n",
      "epoch: 7, test val_loss: 5.982450\n",
      "epoch: 7, train loss: 60.9164, lr: 0.000100, spent: 7.6 secs\n",
      "epoch: 8, test val_loss: 5.981885\n",
      "epoch: 8, train loss: 60.9102, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 9, test val_loss: 5.981443\n",
      "epoch: 9, train loss: 60.9047, lr: 0.000100, spent: 9.5 secs\n",
      "epoch: 10, test val_loss: 5.981087\n",
      "epoch: 10, train loss: 60.9004, lr: 0.000100, spent: 10.3 secs\n",
      "epoch: 11, test val_loss: 5.980795\n",
      "epoch: 11, train loss: 60.8956, lr: 0.000100, spent: 11.3 secs\n",
      "epoch: 12, test val_loss: 5.980557\n",
      "epoch: 12, train loss: 60.8925, lr: 0.000100, spent: 12.3 secs\n",
      "epoch: 13, test val_loss: 5.980357\n",
      "epoch: 13, train loss: 60.8902, lr: 0.000100, spent: 13.3 secs\n",
      "epoch: 14, test val_loss: 5.980191\n",
      "epoch: 14, train loss: 60.8874, lr: 0.000100, spent: 14.5 secs\n",
      "epoch: 15, test val_loss: 5.980050\n",
      "epoch: 15, train loss: 60.8863, lr: 0.000100, spent: 15.5 secs\n",
      "epoch: 16, test val_loss: 5.979930\n",
      "epoch: 16, train loss: 60.8836, lr: 0.000100, spent: 16.4 secs\n",
      "epoch: 17, test val_loss: 5.979825\n",
      "epoch: 17, train loss: 60.8822, lr: 0.000100, spent: 17.3 secs\n",
      "epoch: 18, test val_loss: 5.979736\n",
      "epoch: 18, train loss: 60.8821, lr: 0.000100, spent: 18.2 secs\n",
      "epoch: 19, test val_loss: 5.979657\n",
      "epoch: 19, train loss: 60.8804, lr: 0.000100, spent: 19.2 secs\n",
      "epoch: 20, test val_loss: 5.979586\n",
      "epoch: 20, train loss: 60.8802, lr: 0.000100, spent: 20.3 secs\n",
      "epoch: 21, test val_loss: 5.979523\n",
      "epoch: 21, train loss: 60.8798, lr: 0.000100, spent: 21.3 secs\n",
      "epoch: 22, test val_loss: 5.979466\n",
      "epoch: 22, train loss: 60.8791, lr: 0.000100, spent: 22.2 secs\n",
      "epoch: 23, test val_loss: 5.979416\n",
      "epoch: 23, train loss: 60.8778, lr: 0.000100, spent: 23.2 secs\n",
      "epoch: 24, test val_loss: 5.979372\n",
      "epoch: 24, train loss: 60.8776, lr: 0.000100, spent: 24.2 secs\n",
      "epoch: 25, test val_loss: 5.979331\n",
      "epoch: 25, train loss: 60.8773, lr: 0.000100, spent: 25.1 secs\n",
      "epoch: 26, test val_loss: 5.979294\n",
      "epoch: 26, train loss: 60.8763, lr: 0.000100, spent: 26.0 secs\n",
      "epoch: 27, test val_loss: 5.979260\n",
      "epoch: 27, train loss: 60.8761, lr: 0.000100, spent: 27.0 secs\n",
      "epoch: 28, test val_loss: 5.979229\n",
      "epoch: 28, train loss: 60.8761, lr: 0.000100, spent: 27.9 secs\n",
      "epoch: 29, test val_loss: 5.979202\n",
      "epoch: 29, train loss: 60.8752, lr: 0.000100, spent: 28.9 secs\n",
      "epoch: 30, test val_loss: 5.979175\n",
      "epoch: 30, train loss: 60.8755, lr: 0.000100, spent: 29.8 secs\n",
      "epoch: 31, test val_loss: 5.979153\n",
      "epoch: 31, train loss: 60.8744, lr: 0.000100, spent: 30.8 secs\n",
      "epoch: 32, test val_loss: 5.979130\n",
      "epoch: 32, train loss: 60.8744, lr: 0.000100, spent: 31.6 secs\n",
      "epoch: 33, test val_loss: 5.979110\n",
      "epoch: 33, train loss: 60.8743, lr: 0.000100, spent: 32.4 secs\n",
      "epoch: 34, test val_loss: 5.979090\n",
      "epoch: 34, train loss: 60.8736, lr: 0.000100, spent: 33.4 secs\n",
      "epoch: 35, test val_loss: 5.979072\n",
      "epoch: 35, train loss: 60.8720, lr: 0.000100, spent: 34.3 secs\n",
      "epoch: 36, test val_loss: 5.979054\n",
      "epoch: 36, train loss: 60.8724, lr: 0.000100, spent: 35.2 secs\n",
      "epoch: 37, test val_loss: 5.979037\n",
      "epoch: 37, train loss: 60.8735, lr: 0.000100, spent: 36.2 secs\n",
      "epoch: 38, test val_loss: 5.979021\n",
      "epoch: 38, train loss: 60.8723, lr: 0.000100, spent: 36.9 secs\n",
      "epoch: 39, test val_loss: 5.979005\n",
      "epoch: 39, train loss: 60.8726, lr: 0.000100, spent: 37.8 secs\n",
      "epoch: 40, test val_loss: 5.978991\n",
      "epoch: 40, train loss: 60.8719, lr: 0.000100, spent: 38.5 secs\n",
      "epoch: 41, test val_loss: 5.978974\n",
      "epoch: 41, train loss: 60.8720, lr: 0.000100, spent: 39.3 secs\n",
      "epoch: 42, test val_loss: 5.978961\n",
      "epoch: 42, train loss: 60.8717, lr: 0.000100, spent: 40.1 secs\n",
      "epoch: 43, test val_loss: 5.978946\n",
      "epoch: 43, train loss: 60.8721, lr: 0.000100, spent: 40.9 secs\n",
      "epoch: 44, test val_loss: 5.978931\n",
      "epoch: 44, train loss: 60.8710, lr: 0.000100, spent: 41.8 secs\n",
      "epoch: 45, test val_loss: 5.978916\n",
      "epoch: 45, train loss: 60.8711, lr: 0.000100, spent: 42.7 secs\n",
      "epoch: 46, test val_loss: 5.978899\n",
      "epoch: 46, train loss: 60.8713, lr: 0.000100, spent: 43.6 secs\n",
      "epoch: 47, test val_loss: 5.978884\n",
      "epoch: 47, train loss: 60.8702, lr: 0.000100, spent: 44.6 secs\n",
      "epoch: 48, test val_loss: 5.978869\n",
      "epoch: 48, train loss: 60.8705, lr: 0.000100, spent: 45.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:16.839 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 14:15:16.853 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 49, test val_loss: 5.978854\n",
      "epoch: 49, train loss: 60.8699, lr: 0.000100, spent: 46.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:17.035 | INFO     | transtab.trainer:train:137 - training complete, cost 47.0 secs.\n"
     ]
    }
   ],
   "source": [
    "# make a fast pre-train of TransTab contrastive learning model\n",
    "# build contrastive learner, set supervised=True for supervised VPCL\n",
    "model, collate_fn = transtab.build_contrastive_learner(\n",
    "    cat_cols, num_cols, bin_cols, \n",
    "    supervised=True, # if take supervised CL\n",
    "    num_partition=4, # num of column partitions for pos/neg sampling\n",
    "    overlap_ratio=0.5, # specify the overlap ratio of column partitions during the CL\n",
    ")\n",
    "\n",
    "# start contrastive pretraining training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'batch_size':64,\n",
    "    'lr':1e-4,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }\n",
    "\n",
    "transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5c87e48b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:17.125 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []\n",
      "2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: ['projection_head.dense.weight']\n",
      "2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:775 - load model from ./checkpoint\n",
      "2022-08-31 14:15:17.159 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n"
     ]
    }
   ],
   "source": [
    "# There are two ways to build the encoder\n",
    "# First, take the whole pretrained model and output the cls token embedding at the last layer's outputs\n",
    "enc = transtab.build_encoder(\n",
    "    binary_columns=bin_cols,\n",
    "    checkpoint = './checkpoint'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b8149cfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([700, 128])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.2959e+00,  1.5239e+00, -1.2096e+00,  3.0303e-01,  7.4638e-01,\n",
       "          1.1758e+00,  1.1774e+00, -2.1921e-01,  4.2850e-01,  8.3295e-03,\n",
       "         -5.3477e-01,  1.4859e+00, -2.0534e+00, -9.4093e-01,  3.7010e-01,\n",
       "          1.3663e-01,  4.4837e-01,  1.3882e+00,  1.6472e+00, -1.2430e+00,\n",
       "         -4.8809e-01, -5.1914e-01, -3.3168e-01,  1.9889e+00, -4.9873e-01,\n",
       "          1.2286e+00,  8.6373e-01,  5.1300e-01,  6.7551e-01, -1.2021e+00,\n",
       "          6.3210e-01,  6.2366e-01,  5.6712e-01,  1.2275e-03, -1.5154e+00,\n",
       "          2.0082e+00, -1.2255e+00, -2.4254e-01, -5.1009e-01,  1.6733e+00,\n",
       "         -1.2059e+00, -7.0246e-01,  1.8980e-01, -7.8196e-01,  1.0777e+00,\n",
       "         -6.1830e-01, -1.1279e+00, -1.3290e+00,  9.6929e-01, -7.6388e-02,\n",
       "         -4.5835e-01, -1.1462e+00,  1.5084e+00,  5.7778e-01,  2.0644e-01,\n",
       "          4.3633e-01,  7.6116e-03,  5.2441e-01, -1.9919e-01, -1.9441e-01,\n",
       "          1.8144e+00,  2.7863e-01, -1.8727e+00, -9.4760e-01,  1.1152e+00,\n",
       "          3.5514e-01,  1.6321e+00,  4.3554e-01,  6.1438e-01,  2.2991e-01,\n",
       "          2.3567e-01,  1.0738e+00, -1.0689e+00,  1.1454e+00, -2.9430e-01,\n",
       "         -7.8866e-01,  1.7377e-01,  4.7786e-01, -1.1535e+00, -1.9210e+00,\n",
       "          5.6469e-01, -4.9142e-02, -6.4016e-01, -3.3013e-01, -3.1188e-01,\n",
       "         -7.4673e-01, -3.0021e-01, -2.0609e+00,  7.0935e-01, -6.6764e-01,\n",
       "          6.4810e-01, -8.1043e-02, -1.0044e+00, -2.1534e+00, -1.4149e+00,\n",
       "         -7.6418e-01,  1.9660e+00, -1.0766e+00, -5.2616e-01, -1.2752e+00,\n",
       "          1.1527e+00,  2.2518e-01,  1.7696e-01,  8.3931e-01, -3.5717e-01,\n",
       "          1.4251e-01,  1.6778e+00, -1.5331e+00, -1.5316e+00, -7.3143e-01,\n",
       "         -2.6362e-01, -5.3092e-01,  1.1220e+00,  9.4099e-01, -1.3653e+00,\n",
       "         -5.5385e-01, -2.5665e-01, -3.1621e-01, -1.3123e+00, -9.7127e-02,\n",
       "         -4.2603e-01,  1.8091e+00, -7.5452e-01,  1.9514e+00,  7.2433e-03,\n",
       "          3.7320e-02,  5.3549e-01, -3.9535e-01],\n",
       "        [ 1.4275e+00,  1.4772e+00, -1.1928e+00,  1.8642e-01,  8.1510e-01,\n",
       "          1.2602e+00,  1.2150e+00, -2.1353e-01,  3.9298e-01, -1.8265e-01,\n",
       "         -5.9739e-01,  1.2885e+00, -2.1044e+00, -1.0534e+00,  4.8087e-01,\n",
       "          1.2070e-01,  3.0839e-01,  1.2873e+00,  1.6255e+00, -1.0916e+00,\n",
       "         -3.2920e-01, -2.7017e-01, -3.4054e-01,  2.0612e+00, -6.5718e-01,\n",
       "          1.1547e+00,  9.0340e-01,  5.3138e-01,  7.4846e-01, -1.1599e+00,\n",
       "          6.1057e-01,  6.2320e-01,  6.3401e-01, -7.8121e-02, -1.5336e+00,\n",
       "          1.8799e+00, -1.4002e+00, -3.4578e-01, -8.7409e-01,  1.7005e+00,\n",
       "         -1.2923e+00, -5.9172e-01,  8.2113e-02, -7.6255e-01,  9.8186e-01,\n",
       "         -5.2740e-01, -1.1055e+00, -1.3655e+00,  8.0880e-01,  6.8788e-02,\n",
       "         -5.1715e-01, -1.2682e+00,  1.6060e+00,  5.9163e-01,  3.5197e-01,\n",
       "          6.1037e-01,  1.6449e-01,  4.7828e-01, -2.3575e-01, -2.4127e-01,\n",
       "          1.8397e+00,  3.7601e-01, -1.9676e+00, -9.4222e-01,  1.1711e+00,\n",
       "          3.2122e-01,  1.7164e+00,  4.7828e-01,  7.2740e-01,  2.1730e-01,\n",
       "          2.0191e-01,  7.4816e-01, -1.1957e+00,  1.2826e+00, -3.4407e-01,\n",
       "         -8.6727e-01,  1.4943e-01,  5.4311e-01, -1.1209e+00, -1.8852e+00,\n",
       "          5.8967e-01, -2.3814e-01, -6.1390e-01, -2.7548e-01, -2.5533e-01,\n",
       "         -8.5195e-01, -2.3613e-01, -1.9835e+00,  5.6644e-01, -5.9843e-01,\n",
       "          6.8693e-01,  3.4524e-02, -1.0214e+00, -1.8806e+00, -1.4108e+00,\n",
       "         -7.1087e-01,  1.9959e+00, -1.2109e+00, -6.3984e-01, -9.7635e-01,\n",
       "          1.1544e+00,  2.3031e-01,  2.3562e-01,  6.8024e-01, -2.9665e-01,\n",
       "          1.2141e-01,  1.7590e+00, -1.4833e+00, -1.4007e+00, -9.1892e-01,\n",
       "         -1.3863e-01, -3.3393e-01,  1.0803e+00,  1.0124e+00, -1.4227e+00,\n",
       "         -6.2524e-01, -1.6816e-01, -4.6652e-01, -1.3414e+00, -1.7069e-01,\n",
       "         -2.8513e-01,  1.7853e+00, -9.1653e-01,  1.7702e+00,  2.3768e-01,\n",
       "          9.3338e-02,  5.9862e-01, -3.1038e-01]], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Then take the encoder to get the input embedding\n",
    "df = trainset[0]\n",
    "output = enc(df)\n",
    "print(output.shape)\n",
    "output[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4aadae44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>636</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.061957</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>500&lt;=X&lt;1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.076868</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>736</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.622318</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>922</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.061406</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>511</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.244085</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.232143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>no known property</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "636              0               1  0.294118       0.061957   \n",
       "182              0               1  0.250000       0.076868   \n",
       "736              0               1  0.294118       0.622318   \n",
       "922              0               1  0.073529       0.061406   \n",
       "511              1               1  0.470588       0.244085   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "636                1.000000         0.000000  0.160714          0.000000   \n",
       "182                1.000000         0.333333  0.375000          0.333333   \n",
       "736                0.000000         1.000000  0.071429          0.333333   \n",
       "922                0.666667         1.000000  0.053571          0.000000   \n",
       "511                0.333333         0.333333  0.232143          0.000000   \n",
       "\n",
       "     num_dependents checking_status credit_history   purpose  \\\n",
       "636             0.0     no checking  existing paid  radio/tv   \n",
       "182             1.0              <0       all paid   new car   \n",
       "736             0.0        0<=X<200  existing paid  used car   \n",
       "922             0.0              <0  existing paid  radio/tv   \n",
       "511             0.0     no checking  existing paid  used car   \n",
       "\n",
       "       savings_status employment     personal_status other_parties  \\\n",
       "636       500<=X<1000     4<=X<7  female div/dep/mar          none   \n",
       "182  no known savings     1<=X<4         male single          none   \n",
       "736              <100     1<=X<4  female div/dep/mar          none   \n",
       "922              <100         <1  female div/dep/mar          none   \n",
       "511              <100     1<=X<4         male single          none   \n",
       "\n",
       "    property_magnitude other_payment_plans   housing  \\\n",
       "636                car                none       own   \n",
       "182     life insurance                none       own   \n",
       "736                car                none      rent   \n",
       "922     life insurance                none      rent   \n",
       "511  no known property                none  for free   \n",
       "\n",
       "                           job  \n",
       "636                    skilled  \n",
       "182         unskilled resident  \n",
       "736  high qualif/self emp/mgmt  \n",
       "922                    skilled  \n",
       "511  high qualif/self emp/mgmt  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4f3e1e91",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:16:28.124 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n",
      "2022-08-31 14:16:28.134 | INFO     | transtab.modeling_transtab:load:523 - missing keys: []\n",
      "2022-08-31 14:16:28.135 | INFO     | transtab.modeling_transtab:load:524 - unexpected keys: []\n",
      "2022-08-31 14:16:28.136 | INFO     | transtab.modeling_transtab:load:525 - load model from ./checkpoint\n"
     ]
    }
   ],
   "source": [
    "# Second, if we only want the embedded token-level embeddings (i.e., the embeddings before they enter the transformer layers)\n",
    "enc = transtab.build_encoder(\n",
    "    binary_columns=bin_cols,\n",
    "    checkpoint = './checkpoint',\n",
    "    num_layer = 0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "39a0172b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([700, 85, 128])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.1370,  0.0427, -0.0106,  ..., -0.0806,  0.0518, -0.1315],\n",
       "         [ 0.0657,  0.0341, -0.0128,  ..., -0.0207,  0.0102, -0.0046],\n",
       "         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],\n",
       "         ...,\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]],\n",
       "\n",
       "        [[ 0.1204,  0.0388, -0.0098,  ..., -0.0738,  0.0400, -0.1099],\n",
       "         [ 0.0752,  0.0383, -0.0145,  ..., -0.0174,  0.0190, -0.0085],\n",
       "         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],\n",
       "         ...,\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]]],\n",
       "       device='cuda:0', grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output = enc(df)\n",
    "print(output['embedding'].shape)\n",
    "output['embedding'][:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55936f1e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: examples/transfer_learning.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "134f979d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42c60011",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd62a8df24d14e22a69d77088bd1b220",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.574102\n",
      "epoch: 0, train loss: 3.9759, lr: 0.000100, spent: 0.4 secs\n",
      "epoch: 1, test val_loss: 0.565162\n",
      "epoch: 1, train loss: 3.7812, lr: 0.000100, spent: 0.9 secs\n",
      "epoch: 2, test val_loss: 0.576745\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 2, train loss: 3.6560, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 3, test val_loss: 0.566665\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 3, train loss: 3.6539, lr: 0.000100, spent: 1.4 secs\n",
      "epoch: 4, test val_loss: 0.548929\n",
      "epoch: 4, train loss: 3.6118, lr: 0.000100, spent: 1.7 secs\n",
      "epoch: 5, test val_loss: 0.545800\n",
      "epoch: 5, train loss: 3.5634, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 6, test val_loss: 0.545121\n",
      "epoch: 6, train loss: 3.5035, lr: 0.000100, spent: 2.4 secs\n",
      "epoch: 7, test val_loss: 0.529130\n",
      "epoch: 7, train loss: 3.4372, lr: 0.000100, spent: 2.7 secs\n",
      "epoch: 8, test val_loss: 0.525149\n",
      "epoch: 8, train loss: 3.3768, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 9, test val_loss: 0.518042\n",
      "epoch: 9, train loss: 3.3204, lr: 0.000100, spent: 3.5 secs\n",
      "epoch: 10, test val_loss: 0.508209\n",
      "epoch: 10, train loss: 3.2816, lr: 0.000100, spent: 3.8 secs\n",
      "epoch: 11, test val_loss: 0.497027\n",
      "epoch: 11, train loss: 3.1952, lr: 0.000100, spent: 4.1 secs\n",
      "epoch: 12, test val_loss: 0.495085\n",
      "epoch: 12, train loss: 3.1852, lr: 0.000100, spent: 4.6 secs\n",
      "epoch: 13, test val_loss: 0.479123\n",
      "epoch: 13, train loss: 3.0853, lr: 0.000100, spent: 4.9 secs\n",
      "epoch: 14, test val_loss: 0.492737\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 14, train loss: 3.0682, lr: 0.000100, spent: 5.2 secs\n",
      "epoch: 15, test val_loss: 0.477266\n",
      "epoch: 15, train loss: 2.9653, lr: 0.000100, spent: 5.5 secs\n",
      "epoch: 16, test val_loss: 0.503946\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 16, train loss: 2.9797, lr: 0.000100, spent: 5.7 secs\n",
      "epoch: 17, test val_loss: 0.484869\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 17, train loss: 2.9767, lr: 0.000100, spent: 6.0 secs\n",
      "epoch: 18, test val_loss: 0.467354\n",
      "epoch: 18, train loss: 2.8925, lr: 0.000100, spent: 6.5 secs\n",
      "epoch: 19, test val_loss: 0.471429\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 19, train loss: 2.8963, lr: 0.000100, spent: 6.7 secs\n",
      "epoch: 20, test val_loss: 0.460370\n",
      "epoch: 20, train loss: 2.8847, lr: 0.000100, spent: 7.0 secs\n",
      "epoch: 21, test val_loss: 0.498306\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 21, train loss: 2.8389, lr: 0.000100, spent: 7.4 secs\n",
      "epoch: 22, test val_loss: 0.441738\n",
      "epoch: 22, train loss: 2.8077, lr: 0.000100, spent: 7.7 secs\n",
      "epoch: 23, test val_loss: 0.479452\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 23, train loss: 2.8506, lr: 0.000100, spent: 8.0 secs\n",
      "epoch: 24, test val_loss: 0.450146\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 24, train loss: 2.7006, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 25, test val_loss: 0.460931\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 25, train loss: 2.7361, lr: 0.000100, spent: 8.7 secs\n",
      "epoch: 26, test val_loss: 0.482305\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 26, train loss: 2.6959, lr: 0.000100, spent: 9.0 secs\n",
      "epoch: 27, test val_loss: 0.440060\n",
      "epoch: 27, train loss: 2.7485, lr: 0.000100, spent: 9.3 secs\n",
      "epoch: 28, test val_loss: 0.450090\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 28, train loss: 2.7765, lr: 0.000100, spent: 9.6 secs\n",
      "epoch: 29, test val_loss: 0.472720\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 29, train loss: 2.6344, lr: 0.000100, spent: 9.8 secs\n",
      "epoch: 30, test val_loss: 0.438471\n",
      "epoch: 30, train loss: 2.5639, lr: 0.000100, spent: 10.3 secs\n",
      "epoch: 31, test val_loss: 0.498057\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 31, train loss: 2.7224, lr: 0.000100, spent: 10.6 secs\n",
      "epoch: 32, test val_loss: 0.463493\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 32, train loss: 2.6888, lr: 0.000100, spent: 11.0 secs\n",
      "epoch: 33, test val_loss: 0.435828\n",
      "epoch: 33, train loss: 2.6895, lr: 0.000100, spent: 11.3 secs\n",
      "epoch: 34, test val_loss: 0.495953\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 34, train loss: 2.5385, lr: 0.000100, spent: 11.6 secs\n",
      "epoch: 35, test val_loss: 0.444737\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 35, train loss: 2.5663, lr: 0.000100, spent: 12.1 secs\n",
      "epoch: 36, test val_loss: 0.449832\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 36, train loss: 2.6015, lr: 0.000100, spent: 12.4 secs\n",
      "epoch: 37, test val_loss: 0.441197\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 37, train loss: 2.5011, lr: 0.000100, spent: 12.6 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-10-05 08:35:04.023 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint\n",
      "2022-10-05 08:35:04.042 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint\n",
      "2022-10-05 08:35:04.167 | INFO     | transtab.trainer:train:141 - training complete, cost 13.1 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 38, test val_loss: 0.503903\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# load two datasets and start vanilla supervised training\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols = transtab.load_data(['credit-g', 'credit-approval'])\n",
    "\n",
    "# build transtab classifier model\n",
    "model = transtab.build_classifier(cat_cols, num_cols, bin_cols)\n",
    "\n",
    "# start training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint',\n",
    "    'batch_size':128,\n",
    "    'lr':1e-4,\n",
    "    'weight_decay':1e-4,\n",
    "    }\n",
    "transtab.train(model, trainset[0], valset[0], **training_arguments)\n",
    "\n",
    "# save model\n",
    "model.save('./ckpt/pretrained')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6bdc971",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-10-05 08:35:04.352 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []\n",
      "2022-10-05 08:35:04.354 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: []\n",
      "2022-10-05 08:35:04.355 | INFO     | transtab.modeling_transtab:load:775 - load model from ./ckpt/pretrained\n",
      "2022-10-05 08:35:04.370 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./ckpt/pretrained/extractor/extractor.json\n",
      "2022-10-05 08:35:04.372 | INFO     | transtab.modeling_transtab:update:832 - Build a new classifier with num 2 classes outputs, need further finetune to work.\n"
     ]
    }
   ],
   "source": [
    "# now let's use another dataset and leverage the pretrained model for finetuning.\n",
    "# we already loaded the required dataset `credit-approval` above, so there is no need to load it again.\n",
    "\n",
    "# load the pretrained model\n",
    "model.load('./ckpt/pretrained')\n",
    "\n",
    "# update the model's categorical/numerical/binary column dict.\n",
    "# the number of classes must be specified if the new dataset has a different\n",
    "# number of classes than the pretraining one.\n",
    "model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols, 'num_class':2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f399d02e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b64fb45097e4061af5a0186c17d98a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test auc: 0.282251\n",
      "epoch: 0, train loss: 3.3862, lr: 0.000200, spent: 0.2 secs\n",
      "epoch: 1, test auc: 0.865801\n",
      "epoch: 1, train loss: 2.8794, lr: 0.000200, spent: 0.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, test auc: 0.865801\n",
      "epoch: 2, train loss: 2.5943, lr: 0.000200, spent: 0.7 secs\n",
      "epoch: 3, test auc: 0.865801\n",
      "epoch: 3, train loss: 2.4300, lr: 0.000200, spent: 0.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, test auc: 0.872727\n",
      "epoch: 4, train loss: 2.2617, lr: 0.000200, spent: 1.0 secs\n",
      "epoch: 5, test auc: 0.879654\n",
      "epoch: 5, train loss: 2.0867, lr: 0.000200, spent: 1.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, test auc: 0.880519\n",
      "epoch: 6, train loss: 1.9774, lr: 0.000200, spent: 1.3 secs\n",
      "epoch: 7, test auc: 0.883117\n",
      "epoch: 7, train loss: 1.8739, lr: 0.000200, spent: 1.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, test auc: 0.889177\n",
      "epoch: 8, train loss: 1.8919, lr: 0.000200, spent: 1.5 secs\n",
      "epoch: 9, test auc: 0.890909\n",
      "epoch: 9, train loss: 1.8794, lr: 0.000200, spent: 1.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, test auc: 0.896970\n",
      "epoch: 10, train loss: 1.8456, lr: 0.000200, spent: 2.0 secs\n",
      "epoch: 11, test auc: 0.897835\n",
      "epoch: 11, train loss: 1.8213, lr: 0.000200, spent: 2.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, test auc: 0.896104\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 1.8219, lr: 0.000200, spent: 2.3 secs\n",
      "epoch: 13, test auc: 0.903896\n",
      "epoch: 13, train loss: 1.7924, lr: 0.000200, spent: 2.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, test auc: 0.905628\n",
      "epoch: 14, train loss: 1.7964, lr: 0.000200, spent: 2.6 secs\n",
      "epoch: 15, test auc: 0.904762\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 1.7641, lr: 0.000200, spent: 2.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 16, test auc: 0.904762\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 16, train loss: 1.7788, lr: 0.000200, spent: 2.8 secs\n",
      "epoch: 17, test auc: 0.909091\n",
      "epoch: 17, train loss: 1.7456, lr: 0.000200, spent: 2.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 18, test auc: 0.910823\n",
      "epoch: 18, train loss: 1.7438, lr: 0.000200, spent: 3.3 secs\n",
      "epoch: 19, test auc: 0.912554\n",
      "epoch: 19, train loss: 1.7569, lr: 0.000200, spent: 3.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 20, test auc: 0.912554\n",
      "epoch: 20, train loss: 1.7533, lr: 0.000200, spent: 3.5 secs\n",
      "epoch: 21, test auc: 0.915152\n",
      "epoch: 21, train loss: 1.7439, lr: 0.000200, spent: 3.7 secs\n",
      "epoch: 22, test auc: 0.915152\n",
      "epoch: 22, train loss: 1.7020, lr: 0.000200, spent: 3.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 23, test auc: 0.916883\n",
      "epoch: 23, train loss: 1.7017, lr: 0.000200, spent: 4.0 secs\n",
      "epoch: 24, test auc: 0.917749\n",
      "epoch: 24, train loss: 1.6625, lr: 0.000200, spent: 4.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 25, test auc: 0.918615\n",
      "epoch: 25, train loss: 1.6432, lr: 0.000200, spent: 4.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 26, test auc: 0.922944\n",
      "epoch: 26, train loss: 1.6299, lr: 0.000200, spent: 4.7 secs\n",
      "epoch: 27, test auc: 0.922944\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 27, train loss: 1.6158, lr: 0.000200, spent: 4.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 28, test auc: 0.925541\n",
      "epoch: 28, train loss: 1.5971, lr: 0.000200, spent: 4.9 secs\n",
      "epoch: 29, test auc: 0.926407\n",
      "epoch: 29, train loss: 1.5771, lr: 0.000200, spent: 5.0 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 30, test auc: 0.927273\n",
      "epoch: 30, train loss: 1.5763, lr: 0.000200, spent: 5.2 secs\n",
      "epoch: 31, test auc: 0.933333\n",
      "epoch: 31, train loss: 1.6021, lr: 0.000200, spent: 5.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 32, test auc: 0.936797\n",
      "epoch: 32, train loss: 1.5513, lr: 0.000200, spent: 5.5 secs\n",
      "epoch: 33, test auc: 0.938528\n",
      "epoch: 33, train loss: 1.5160, lr: 0.000200, spent: 5.6 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 34, test auc: 0.938528\n",
      "epoch: 34, train loss: 1.5250, lr: 0.000200, spent: 5.8 secs\n",
      "epoch: 35, test auc: 0.938528\n",
      "epoch: 35, train loss: 1.4732, lr: 0.000200, spent: 6.0 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 36, test auc: 0.934199\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 36, train loss: 1.4738, lr: 0.000200, spent: 6.1 secs\n",
      "epoch: 37, test auc: 0.934199\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 37, train loss: 1.4667, lr: 0.000200, spent: 6.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 38, test auc: 0.933333\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 38, train loss: 1.4209, lr: 0.000200, spent: 6.3 secs\n",
      "epoch: 39, test auc: 0.933333\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 39, train loss: 1.4371, lr: 0.000200, spent: 6.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "2022-10-05 08:35:10.982 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint\n",
      "2022-10-05 08:35:10.994 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint\n",
      "2022-10-05 08:35:11.142 | INFO     | transtab.trainer:train:141 - training complete, cost 6.7 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 40, test auc: 0.929870\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'auc',\n",
    "    'eval_less_is_better':False,\n",
    "    'output_dir':'./checkpoint',\n",
    "    'batch_size':128,\n",
    "    'lr':2e-4,\n",
    "    }\n",
    "\n",
    "transtab.train(model, trainset[1], valset[1], **training_arguments)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3aa87021",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.8757(0.05)\n",
      "0.8807749627421758\n"
     ]
    }
   ],
   "source": [
    "# evaluation\n",
    "x_test, y_test = testset[1]\n",
    "ypred = transtab.predict(model, x_test)\n",
    "transtab.evaluate(ypred, y_test, metric='auc')\n",
    "\n",
    "from sklearn.metrics import roc_auc_score\n",
    "print(roc_auc_score(y_test, ypred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('pytrial': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "2f00ab411e3cfe281b54106f98420bd06c3920b043d7b3741a63d2a4ac576305"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: examples/transfer_learning_regressor.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "739e0cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import os\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "134f979d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "668517ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a64015e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: openml in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (0.15.1)\n",
      "Requirement already satisfied: liac-arff>=2.4.0 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (2.5.0)\n",
      "Requirement already satisfied: xmltodict in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (0.14.2)\n",
      "Requirement already satisfied: requests in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (2.32.3)\n",
      "Requirement already satisfied: scikit-learn>=0.18 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (1.6.1)\n",
      "Requirement already satisfied: python-dateutil in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (2.9.0.post0)\n",
      "Requirement already satisfied: pandas>=1.0.0 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (2.2.3)\n",
      "Requirement already satisfied: scipy>=0.13.3 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (1.15.2)\n",
      "Requirement already satisfied: numpy>=1.6.2 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (2.2.3)\n",
      "Requirement already satisfied: minio in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (7.2.15)\n",
      "Requirement already satisfied: pyarrow in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (19.0.1)\n",
      "Requirement already satisfied: tqdm in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (4.67.1)\n",
      "Requirement already satisfied: packaging in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from openml) (24.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from pandas>=1.0.0->openml) (2025.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from pandas>=1.0.0->openml) (2025.1)\n",
      "Requirement already satisfied: six>=1.5 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from python-dateutil->openml) (1.17.0)\n",
      "Requirement already satisfied: joblib>=1.2.0 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from scikit-learn>=0.18->openml) (1.4.2)\n",
      "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from scikit-learn>=0.18->openml) (3.5.0)\n",
      "Requirement already satisfied: certifi in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from minio->openml) (2025.1.31)\n",
      "Requirement already satisfied: urllib3 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from minio->openml) (2.3.0)\n",
      "Requirement already satisfied: argon2-cffi in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from minio->openml) (23.1.0)\n",
      "Requirement already satisfied: pycryptodome in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from minio->openml) (3.21.0)\n",
      "Requirement already satisfied: typing-extensions in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from minio->openml) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from requests->openml) (3.4.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from requests->openml) (3.10)\n",
      "Requirement already satisfied: argon2-cffi-bindings in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from argon2-cffi->minio->openml) (21.2.0)\n",
      "Requirement already satisfied: cffi>=1.0.1 in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from argon2-cffi-bindings->argon2-cffi->minio->openml) (1.17.1)\n",
      "Requirement already satisfied: pycparser in /home/zifengw2/miniconda3/envs/digitaltwin/lib/python3.10/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->minio->openml) (2.22)\n",
      "########################################\n"
     ]
    },
    {
     "ename": "ImportError",
     "evalue": "OpenML is required for this functionality. Please install it with: pip install openml",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# load a dataset and start vanilla supervised training\u001b[39;00m\n\u001b[1;32m      2\u001b[0m get_ipython()\u001b[38;5;241m.\u001b[39msystem(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpip install openml\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \u001b[38;5;241m=\u001b[39m \u001b[43mtranstab\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcredit-g\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcredit-approval\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/github/transtab/transtab/dataset.py:95\u001b[0m, in \u001b[0;36mload_data\u001b[0;34m(dataname, dataset_config, encode_cat, data_cut, seed)\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataname_ \u001b[38;5;129;01min\u001b[39;00m dataname:\n\u001b[1;32m     93\u001b[0m     data_config \u001b[38;5;241m=\u001b[39m dataset_config\u001b[38;5;241m.\u001b[39mget(dataname_, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m     94\u001b[0m     allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \u001b[38;5;241m=\u001b[39m \\\n\u001b[0;32m---> 95\u001b[0m         \u001b[43mload_single_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataname_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_cat\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_cat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_cut\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_cut\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     96\u001b[0m     num_col_list\u001b[38;5;241m.\u001b[39mextend(num_cols)\n\u001b[1;32m     97\u001b[0m     cat_col_list\u001b[38;5;241m.\u001b[39mextend(cat_cols)\n",
      "File \u001b[0;32m~/github/transtab/transtab/dataset.py:159\u001b[0m, in \u001b[0;36mload_single_data\u001b[0;34m(dataname, dataset_config, encode_cat, data_cut, seed)\u001b[0m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    158\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _has_openml:\n\u001b[0;32m--> 159\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m    160\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOpenML is required for this functionality. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    161\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease install it with: pip install openml\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    162\u001b[0m         )\n\u001b[1;32m    163\u001b[0m     dataset \u001b[38;5;241m=\u001b[39m openml\u001b[38;5;241m.\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mget_dataset(dataname)\n\u001b[1;32m    164\u001b[0m     X,y,categorical_indicator, attribute_names \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39mget_data(dataset_format\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdataframe\u001b[39m\u001b[38;5;124m'\u001b[39m, target\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mdefault_target_attribute)\n",
      "\u001b[0;31mImportError\u001b[0m: OpenML is required for this functionality. Please install it with: pip install openml"
     ]
    }
   ],
   "source": [
    "# load a dataset and start vanilla supervised training\n",
    "# !pip install openml\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols = transtab.load_data(['credit-g', 'credit-approval'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "521fb369",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'trainset' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m trainset_reg \u001b[38;5;241m=\u001b[39m [(\u001b[43mtrainset\u001b[49m[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m0\u001b[39m], pd\u001b[38;5;241m.\u001b[39mSeries(np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(trainset[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]))), (trainset[\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m0\u001b[39m], pd\u001b[38;5;241m.\u001b[39mSeries(np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(trainset[\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])))]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'trainset' is not defined"
     ]
    }
   ],
   "source": [
    "trainset_reg = [(trainset[0][0], pd.Series(np.random.randn(trainset[0][0].shape[0]))), (trainset[1][0], pd.Series(np.random.randn(trainset[1][0].shape[0])))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "cadc940f",
   "metadata": {},
   "outputs": [],
   "source": [
    "valset_reg = [(valset[0][0], pd.Series(np.random.randn(valset[0][0].shape[0]))), (valset[1][0], pd.Series(np.random.randn(valset[1][0].shape[0])))]\n",
    "testset_reg = [(testset[0][0], pd.Series(np.random.randn(testset[0][0].shape[0]))), (testset[1][0], pd.Series(np.random.randn(testset[1][0].shape[0])))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "42c60011",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   2%|▏         | 1/50 [00:01<01:33,  1.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 1.372377\n",
      "epoch: 0, train loss: 6.7940, lr: 0.000100, spent: 1.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   4%|▍         | 2/50 [00:03<01:22,  1.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, test val_loss: 1.184756\n",
      "epoch: 1, train loss: 6.0480, lr: 0.000100, spent: 3.5 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   6%|▌         | 3/50 [00:05<01:18,  1.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, test val_loss: 1.194661\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 2, train loss: 6.1002, lr: 0.000100, spent: 5.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   8%|▊         | 4/50 [00:06<01:14,  1.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, test val_loss: 1.218926\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 3, train loss: 5.8850, lr: 0.000100, spent: 6.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  10%|█         | 5/50 [00:08<01:12,  1.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, test val_loss: 1.198663\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 4, train loss: 5.9642, lr: 0.000100, spent: 8.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  12%|█▏        | 6/50 [00:09<01:10,  1.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, test val_loss: 1.205427\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 5, train loss: 5.8004, lr: 0.000100, spent: 9.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  12%|█▏        | 6/50 [00:11<01:24,  1.92s/it]\n",
      "\u001b[32m2024-03-08 16:32:55.367\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36mtrain\u001b[0m:\u001b[36m136\u001b[0m - \u001b[1mload best at last from ./checkpoint\u001b[0m\n",
      "\u001b[32m2024-03-08 16:32:55.378\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36msave_model\u001b[0m:\u001b[36m247\u001b[0m - \u001b[1msaving model checkpoint to ./checkpoint\u001b[0m\n",
      "\u001b[32m2024-03-08 16:32:55.471\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36mtrain\u001b[0m:\u001b[36m141\u001b[0m - \u001b[1mtraining complete, cost 11.6 secs.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, test val_loss: 1.198403\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# build transtab regressor model\n",
    "model = transtab.build_regressor(cat_cols, num_cols, bin_cols, device='cpu')\n",
    "\n",
    "# start training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint',\n",
    "    'batch_size':128,\n",
    "    'lr':1e-4,\n",
    "    'weight_decay':1e-4,\n",
    "    }\n",
    "transtab.train(model, trainset_reg[0], valset_reg[0], **training_arguments)\n",
    "\n",
    "# save model\n",
    "model.save('./ckpt/pretrained')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "d6bdc971",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-03-08 16:33:11.448\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.modeling_transtab\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m787\u001b[0m - \u001b[1mmissing keys: []\u001b[0m\n",
      "\u001b[32m2024-03-08 16:33:11.448\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.modeling_transtab\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m788\u001b[0m - \u001b[1munexpected keys: []\u001b[0m\n",
      "\u001b[32m2024-03-08 16:33:11.449\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.modeling_transtab\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m789\u001b[0m - \u001b[1mload model from ./ckpt/pretrained\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-03-08 16:33:11.468\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.modeling_transtab\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m222\u001b[0m - \u001b[1mload feature extractor from ./ckpt/pretrained/extractor/extractor.json\u001b[0m\n",
      "\u001b[32m2024-03-08 16:33:11.470\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.modeling_transtab\u001b[0m:\u001b[36m_adapt_to_new_num_class\u001b[0m:\u001b[36m886\u001b[0m - \u001b[1mBuild a new classifier with num 2 classes outputs, need further finetune to work.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# now let's use another dataset and leverage the pretrained model for finetuning\n",
    "# here we have loaded the required data `credit-approval` before, no need to load again.\n",
    "\n",
    "# load the pretrained model\n",
    "model.load('./ckpt/pretrained')\n",
    "\n",
    "# update model's categorical/numerical/binary column dict\n",
    "# need to specify the number of classes if the new dataset has different # of classes from the \n",
    "# pretrained one.\n",
    "model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols, 'num_class':2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "f399d02e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   2%|▏         | 1/50 [00:00<00:37,  1.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test mse: 0.814842\n",
      "epoch: 0, train loss: 2.9249, lr: 0.000200, spent: 0.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   4%|▍         | 2/50 [00:01<00:31,  1.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, test mse: 0.803411\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 1, train loss: 0.1003, lr: 0.000200, spent: 1.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   6%|▌         | 3/50 [00:01<00:29,  1.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, test mse: 0.802998\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 2, train loss: -0.3084, lr: 0.000200, spent: 2.0 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   8%|▊         | 4/50 [00:02<00:28,  1.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, test mse: 0.802881\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 3, train loss: -0.3803, lr: 0.000200, spent: 2.6 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  10%|█         | 5/50 [00:03<00:28,  1.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, test mse: 0.802826\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 4, train loss: -0.2638, lr: 0.000200, spent: 3.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  10%|█         | 5/50 [00:03<00:33,  1.34it/s]\n",
      "\u001b[32m2024-03-08 16:37:52.614\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36mtrain\u001b[0m:\u001b[36m136\u001b[0m - \u001b[1mload best at last from ./checkpoint\u001b[0m\n",
      "\u001b[32m2024-03-08 16:37:52.621\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36msave_model\u001b[0m:\u001b[36m247\u001b[0m - \u001b[1msaving model checkpoint to ./checkpoint\u001b[0m\n",
      "\u001b[32m2024-03-08 16:37:52.718\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtranstab.trainer\u001b[0m:\u001b[36mtrain\u001b[0m:\u001b[36m141\u001b[0m - \u001b[1mtraining complete, cost 3.9 secs.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, test mse: 0.802803\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'mse',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint',\n",
    "    'batch_size':128,\n",
    "    'lr':2e-4,\n",
    "    }\n",
    "\n",
    "transtab.train(model, trainset_reg[1], valset_reg[1], **training_arguments)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "3aa87021",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9819256995837686\n"
     ]
    }
   ],
   "source": [
    "# evaluation\n",
    "x_test, y_test = testset_reg[1]\n",
    "ypred = transtab.predict(model, x_test, y_test)\n",
    "transtab.evaluate(ypred, y_test, metric='mse')\n",
    "\n",
    "from sklearn.metrics import mean_squared_error\n",
    "print(mean_squared_error(y_test, ypred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4bf1d31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digitaltwin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: pypi_build_commands.txt
================================================
# Command list for building and publishing PyPI packages
python setup.py sdist bdist_wheel

twine check dist/*

# upload to pypi-test
python -m twine upload --repository-url https://test.pypi.org/legacy/ dist/*

# install from test-pypi
pip install --index-url https://test.pypi.org/simple/ transtab==0.0.2c

# upload to pypi
twine upload dist/*


================================================
FILE: requirements.txt
================================================
numpy
scikit_learn
setuptools
transformers<=4.30.0
tqdm
pandas>=1.3.0
openml>=0.10.0


================================================
FILE: setup.py
================================================
import os
import setuptools

this_directory = os.path.abspath(os.path.dirname(__file__))

with open("README.md", "r") as f:
    long_description = f.read()

# read the contents of requirements.txt
with open(os.path.join(this_directory, 'requirements.txt'),
          encoding='utf-8') as f:
    requirements = f.read().splitlines()

setuptools.setup(
    name = 'transtab',
    version = '0.0.7',
    author = 'Zifeng Wang',
    author_email = 'zifengw2@illinois.edu',
    description = 'A flexible tabular prediction model that handles variable-column input tables.',
    url = 'https://github.com/RyanWangZf/transtab',
    keywords=['tabular data', 'machine learning', 'data mining', 'data science'],
    long_description=long_description,
    long_description_content_type='text/markdown',
    packages=setuptools.find_packages(exclude=['test']),
    install_requires=requirements,
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
    ],
)


================================================
FILE: transtab/__init__.py
================================================
name = 'transtab'
version = '0.0.7'

from .transtab import *


================================================
FILE: transtab/constants.py
================================================
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.json"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
WEIGHTS_NAME = "pytorch_model.bin"
TOKENIZER_DIR = 'tokenizer'
EXTRACTOR_STATE_DIR = 'extractor'
EXTRACTOR_STATE_NAME = 'extractor.json'
INPUT_ENCODER_NAME = 'input_encoder.bin'

================================================
FILE: transtab/dataset.py
================================================
import os
import pdb

import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, MinMaxScaler
from sklearn.model_selection import train_test_split

try:
    import openml
    _has_openml = True
except ImportError:
    _has_openml = False

import logging
logger = logging.getLogger(__name__)

OPENML_DATACONFIG = {
    'credit-g': {'bin': ['own_telephone', 'foreign_worker']},
}

EXAMPLE_DATACONFIG = {
    "example": {
        "bin": ["bin1", "bin2"],
        "cat": ["cat1", "cat2"],
        "num": ["num1", "num2"],
        "cols": ["bin1", "bin2", "cat1", "cat2", "num1", "num2"],
        "binary_indicator": ["1", "yes", "true", "positive", "t", "y"],
        "data_split_idx": {
            "train":[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "val":[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
            "test":[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        }
    }
}

def load_data(dataname, dataset_config=None, encode_cat=False, data_cut=None, seed=123):
    '''Load datasets from the local device or from openml.datasets.

    Parameters
    ----------
    dataname: str or int
        the dataset name/index to load from openml, or the directory of the local dataset.

    dataset_config: dict
        the dataset configuration used for loading. Note that this variable
        overrides the configuration loaded from local files or from openml.datasets.

    encode_cat: bool
        whether to encode the categorical/binary columns as discrete indices; keep False for TransTab models.

    data_cut: int
        the number of equal partitions to split the raw tables into; set None to skip partitioning.

    seed: int
        the random seed used to fix the train/val/test split.

    Returns
    -------
    all_list: list or tuple
        the complete dataset, either (x, y) or [(x1, y1), (x2, y2), ...].

    train_list: list or tuple
        the train dataset, either (x, y) or [(x1, y1), (x2, y2), ...].

    val_list: list or tuple
        the validation dataset, either (x, y) or [(x1, y1), (x2, y2), ...].

    test_list: list
        the test dataset, either (x, y) or [(x1, y1), (x2, y2), ...].

    cat_col_list: list
        the list of categorical column names.

    num_col_list: list
        the list of numerical column names.

    bin_col_list: list
        the list of binary column names.
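
    Examples
    --------
    A minimal sketch of the per-dataset ``dataset_config`` lookup this
    function performs (``resolve_config`` is a hypothetical helper, not part
    of transtab; it only mirrors the fallback to ``OPENML_DATACONFIG`` and
    the per-name lookup):

    ```python
    # illustrative stand-in for the config lookup inside load_data
    OPENML_DATACONFIG = {'credit-g': {'bin': ['own_telephone', 'foreign_worker']}}

    def resolve_config(dataname, dataset_config=None):
        # fall back to the built-in OpenML config when none is supplied
        if dataset_config is None:
            dataset_config = OPENML_DATACONFIG
        # per-dataset entry, or None when no special columns are declared
        return dataset_config.get(dataname, None)

    print(resolve_config('credit-g'))         # {'bin': ['own_telephone', 'foreign_worker']}
    print(resolve_config('credit-approval'))  # None
    ```

    Note that a user-supplied ``dataset_config`` replaces the built-in one
    entirely, so datasets absent from it resolve to ``None``.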

    '''
    if dataset_config is None: dataset_config = OPENML_DATACONFIG
    if isinstance(dataname, str):
        # load a single tabular data
        return load_single_data(dataname=dataname, dataset_config=dataset_config, encode_cat=encode_cat, data_cut=data_cut, seed=seed)
    
    if isinstance(dataname, list):
        # load a list of datasets, combine together and outputs
        num_col_list, cat_col_list, bin_col_list = [], [], []
        all_list = []
        train_list, val_list, test_list = [], [], []
        for dataname_ in dataname:
            data_config = dataset_config.get(dataname_, None)
            allset, trainset, valset, testset, cat_cols, num_cols, bin_cols = \
                load_single_data(dataname_, dataset_config=data_config, encode_cat=encode_cat, data_cut=data_cut, seed=seed)
            num_col_list.extend(num_cols)
            cat_col_list.extend(cat_cols)
            bin_col_list.extend(bin_cols)
            all_list.append(allset)
            train_list.append(trainset)
            val_list.append(valset)
            test_list.append(testset)
        return all_list, train_list, val_list, test_list, cat_col_list, num_col_list, bin_col_list

def load_single_data(dataname, dataset_config=None, encode_cat=False, data_cut=None, seed=123):
    '''Load a tabular dataset from a local directory or from the openml public database.
    args:
        dataname: Either a local data directory `./data/{dataname}` or a dataset name/index found in the openml database.
        dataset_config:
            A dict like {'bin': [col1, col2, ...]} to indicate the binary columns of the data obtained from openml.
            It can also contain {'columns': [col1, col2, ...]} to assign a new set of column names to the data.
        encode_cat: Set `False` when using transtab; set `True` to encode categorical values into indices for baseline models.
        data_cut: The number of cuts of the training set. Cuts are performed on both rows and columns.
    outputs:
        allset: (X, y) containing all samples of this dataset
        trainset, valset, testset: the train/val/test splits
        num_cols, cat_cols, bin_cols: the lists of numerical/categorical/binary column names
    '''
    print('####'*10)
    if os.path.exists(dataname):
        print(f'load from local data dir {dataname}')
        filename = os.path.join(dataname, 'data_processed.csv')
        df = pd.read_csv(filename, index_col=0)
        y = df['target_label']
        X = df.drop(['target_label'],axis=1)
        all_cols = [col.lower() for col in X.columns.tolist()]

        X.columns = all_cols
        attribute_names = all_cols
        ftfile = os.path.join(dataname, 'numerical_feature.txt')
        if os.path.exists(ftfile):
            with open(ftfile,'r') as f: num_cols = [x.strip().lower() for x in f.readlines()]
        else:
            num_cols = []
        bnfile = os.path.join(dataname, 'binary_feature.txt')
        if os.path.exists(bnfile):
            with open(bnfile,'r') as f: bin_cols = [x.strip().lower() for x in f.readlines()]
        else:
            bin_cols = []
        cat_cols = [col for col in all_cols if col not in num_cols and col not in bin_cols]

        # update cols by loading dataset_config
        if dataset_config is not None:
            if 'columns' in dataset_config:
                new_cols = dataset_config['columns']
                X.columns = new_cols

            if 'bin' in dataset_config:
                bin_cols = dataset_config['bin']
            
            if 'cat' in dataset_config:
                cat_cols = dataset_config['cat']

            if 'num' in dataset_config:
                num_cols = dataset_config['num']
        
    else:
        if not _has_openml:
            raise ImportError(
                "OpenML is required for this functionality. "
                "Please install it with: pip install openml"
            )
        dataset = openml.datasets.get_dataset(dataname)
        X,y,categorical_indicator, attribute_names = dataset.get_data(dataset_format='dataframe', target=dataset.default_target_attribute)
        
        if isinstance(dataname, int):
            openml_list = openml.datasets.list_datasets(output_format="dataframe")  # returns a dataframe
            dataname = openml_list.loc[openml_list.did == dataname].name.values[0]
        else:
            openml_list = openml.datasets.list_datasets(output_format="dataframe")  # returns a dataframe
            print(f'openml data index: {openml_list.loc[openml_list.name == dataname].index[0]}')
        
        print(f'load data from {dataname}')

        # drop cols which only have one unique value
        drop_cols = [col for col in attribute_names if X[col].nunique()<=1]

        all_cols = np.array(attribute_names)
        categorical_indicator = np.array(categorical_indicator)
        cat_cols = [col for col in all_cols[categorical_indicator] if col not in drop_cols]
        num_cols = [col for col in all_cols[~categorical_indicator] if col not in drop_cols]
        all_cols = [col for col in all_cols if col not in drop_cols]
        
        if dataset_config is not None:
            if 'bin' in dataset_config: bin_cols = [c for c in cat_cols if c in dataset_config['bin']]
        else: bin_cols = []
        cat_cols = [c for c in cat_cols if c not in bin_cols]

        # encode target label
        y = LabelEncoder().fit_transform(y.values)
        y = pd.Series(y,index=X.index)

    # start processing features
    # process num
    if len(num_cols) > 0:
        for col in num_cols: X[col].fillna(X[col].mode()[0], inplace=True)
        X[num_cols] = MinMaxScaler().fit_transform(X[num_cols])

    if len(cat_cols) > 0:
        for col in cat_cols: X[col].fillna(X[col].mode()[0], inplace=True)
        # process cate
        if encode_cat:
            X[cat_cols] = OrdinalEncoder().fit_transform(X[cat_cols])
        else:
            X[cat_cols] = X[cat_cols].astype(str)

    if len(bin_cols) > 0:
        for col in bin_cols: X[col].fillna(X[col].mode()[0], inplace=True)
        if dataset_config is not None and 'binary_indicator' in dataset_config:
            X[bin_cols] = X[bin_cols].astype(str).applymap(lambda x: 1 if x.lower() in dataset_config['binary_indicator'] else 0).values
        else:
            X[bin_cols] = X[bin_cols].astype(str).applymap(lambda x: 1 if x.lower() in ['yes','true','1','t'] else 0).values        
        
        # if no dataset_config is given, the default indicator list above is used;
        # raise an error if the binary columns still contain values other than 0/1
        if (~X[bin_cols].isin([0,1])).any().any():
            raise ValueError(f'binary columns {bin_cols} contain values other than 0/1.')

    
    X = X[bin_cols + num_cols + cat_cols]

    # rename column names if is given
    if dataset_config is not None:
        data_config = dataset_config
        if 'columns' in data_config:
            new_cols = data_config['columns']
            X.columns = new_cols
            attribute_names = new_cols

        if 'bin' in data_config:
            bin_cols = data_config['bin']
        
        if 'cat' in data_config:
            cat_cols = data_config['cat']

        if 'num' in data_config:
            num_cols = data_config['num']


    # split train/val/test
    data_split_idx = None
    if dataset_config is not None:
        data_split_idx = dataset_config.get('data_split_idx', None)

    if data_split_idx is not None:
        train_idx = data_split_idx.get('train', None)
        val_idx = data_split_idx.get('val', None)
        test_idx = data_split_idx.get('test', None)

        if train_idx is None or test_idx is None:
            raise ValueError('train and test split indices must both be provided')
        else:
            train_dataset = X.iloc[train_idx]
            y_train = y[train_idx]
            test_dataset = X.iloc[test_idx]
            y_test = y[test_idx]
            if val_idx is not None:
                val_dataset = X.iloc[val_idx]
                y_val = y[val_idx]
            else:
                val_dataset = None
                y_val = None
    else:
        # split train/val/test
        train_dataset, test_dataset, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed, stratify=y, shuffle=True)
        val_size = int(len(y)*0.1)
        val_dataset = train_dataset.iloc[-val_size:]
        y_val = y_train[-val_size:]
        train_dataset = train_dataset.iloc[:-val_size]
        y_train = y_train[:-val_size]

    if data_cut is not None:
        np.random.shuffle(all_cols)
        sp_size=int(len(all_cols)/data_cut)
        col_splits = np.split(all_cols, range(0,len(all_cols),sp_size))[1:]
        new_col_splits = []
        for split in col_splits:
            candidate_cols = np.random.choice(np.setdiff1d(all_cols, split), int(sp_size/2), replace=False)
            new_col_splits.append(split.tolist() + candidate_cols.tolist())
        if len(col_splits) > data_cut:
            for i in range(len(col_splits[-1])):
                new_col_splits[i] += [col_splits[-1][i]]
                new_col_splits[i] = np.unique(new_col_splits[i]).tolist()
            new_col_splits = new_col_splits[:-1]

        # cut subset
        trainset_splits = np.array_split(train_dataset, data_cut)
        train_subset_list = []
        for i in range(data_cut):
            train_subset_list.append(
                (trainset_splits[i][new_col_splits[i]], y_train.loc[trainset_splits[i].index])
            )
        print('# data: {}, # feat: {}, # cate: {},  # bin: {}, # numerical: {}, pos rate: {:.2f}'.format(len(X), len(attribute_names), len(cat_cols), len(bin_cols), len(num_cols), (y==1).sum()/len(y)))
        return (X, y), train_subset_list, (val_dataset,y_val), (test_dataset, y_test), cat_cols, num_cols, bin_cols

    else:
        print('# data: {}, # feat: {}, # cate: {},  # bin: {}, # numerical: {}, pos rate: {:.2f}'.format(len(X), len(attribute_names), len(cat_cols), len(bin_cols), len(num_cols), (y==1).sum()/len(y)))
        return (X,y), (train_dataset,y_train), (val_dataset,y_val), (test_dataset, y_test), cat_cols, num_cols, bin_cols

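The local-directory branch of `load_single_data` above expects a specific layout: a `data_processed.csv` whose `target_label` column holds the label, plus optional `numerical_feature.txt` and `binary_feature.txt` files listing one column name per line (any column listed in neither file is treated as categorical). Below is a minimal sketch of preparing such a directory; the feature names and values are invented for illustration.

```python
import csv, os, tempfile

data_dir = tempfile.mkdtemp(prefix="credit-demo")

# data_processed.csv: first column is the index, 'target_label' holds the label.
rows = [
    ("idx", "age", "income", "owns_home", "target_label"),
    (0, 25, 0.30, "yes", 1),
    (1, 40, 0.75, "no", 0),
]
with open(os.path.join(data_dir, "data_processed.csv"), "w", newline="") as f:
    csv.writer(f).writerows(rows)

# One column name per line; columns listed in neither file count as categorical.
with open(os.path.join(data_dir, "numerical_feature.txt"), "w") as f:
    f.write("age\nincome\n")
with open(os.path.join(data_dir, "binary_feature.txt"), "w") as f:
    f.write("owns_home\n")

# With transtab installed, the directory could then be loaded as:
# allset, trainset, valset, testset, cat_cols, num_cols, bin_cols = \
#     transtab.load_data(data_dir)
print(sorted(os.listdir(data_dir)))
```

With transtab installed, `transtab.load_data(data_dir)` would then return the full set, the train/val/test splits, and the three column-name lists as documented above.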
================================================
FILE: transtab/evaluator.py
================================================
from collections import defaultdict
import os
SYMBOL INDEX (115 symbols across 6 files)

FILE: transtab/dataset.py
  function load_data (line 37) | def load_data(dataname, dataset_config=None, encode_cat=False, data_cut=...
  function load_single_data (line 105) | def load_single_data(dataname, dataset_config=None, encode_cat=False, da...

FILE: transtab/evaluator.py
  function predict (line 11) | def predict(clf,
  function evaluate (line 72) | def evaluate(ypred, y_test, metric='auc', seed=123, bootstrap=False):
  function get_eval_metric_fn (line 101) | def get_eval_metric_fn(eval_metric):
  function acc_fn (line 110) | def acc_fn(y, p):
  function auc_fn (line 114) | def auc_fn(y, p):
  function mse_fn (line 117) | def mse_fn(y, p):
  class EarlyStopping (line 120) | class EarlyStopping:
    method __init__ (line 122) | def __init__(self, patience=7, verbose=False, delta=0, output_dir='ckp...
    method __call__ (line 148) | def __call__(self, val_loss, model):
    method save_checkpoint (line 170) | def save_checkpoint(self, val_loss, model):

FILE: transtab/modeling_transtab.py
  class TransTabWordEmbedding (line 21) | class TransTabWordEmbedding(nn.Module):
    method __init__ (line 25) | def __init__(self,
    method forward (line 38) | def forward(self, input_ids) -> Tensor:
  class TransTabNumEmbedding (line 44) | class TransTabNumEmbedding(nn.Module):
    method __init__ (line 48) | def __init__(self, hidden_dim) -> None:
    method forward (line 54) | def forward(self, num_col_emb, x_num_ts, num_mask=None) -> Tensor:
  class TransTabFeatureExtractor (line 64) | class TransTabFeatureExtractor:
    method __init__ (line 69) | def __init__(self,
    method __call__ (line 118) | def __call__(self, x, shuffle=False) -> Dict:
    method save (line 190) | def save(self, path):
    method load (line 211) | def load(self, path):
    method update (line 226) | def update(self, cat=None, num=None, bin=None):
    method _check_column_overlap (line 249) | def _check_column_overlap(self, cat_cols=None, num_cols=None, bin_cols...
    method _solve_duplicate_cols (line 262) | def _solve_duplicate_cols(self, duplicate_cols):
  class TransTabFeatureProcessor (line 275) | class TransTabFeatureProcessor(nn.Module):
    method __init__ (line 279) | def __init__(self,
    method _avg_embedding_by_mask (line 303) | def _avg_embedding_by_mask(self, embs, att_mask=None):
    method forward (line 311) | def forward(self,
  function _get_activation_fn (line 364) | def _get_activation_fn(activation):
  class TransTabTransformerLayer (line 375) | class TransTabTransformerLayer(nn.Module):
    method __init__ (line 377) | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, ...
    method _sa_block (line 409) | def _sa_block(self, x: Tensor,
    method _ff_block (line 420) | def _ff_block(self, x: Tensor) -> Tensor:
    method __setstate__ (line 427) | def __setstate__(self, state):
    method forward (line 432) | def forward(self, src, src_mask= None, src_key_padding_mask= None, is_...
  class TransTabInputEncoder (line 458) | class TransTabInputEncoder(nn.Module):
    method __init__ (line 493) | def __init__(self,
    method forward (line 504) | def forward(self, x):
    method load (line 517) | def load(self, ckpt_dir):
  class TransTabEncoder (line 529) | class TransTabEncoder(nn.Module):
    method __init__ (line 530) | def __init__(self,
    method forward (line 567) | def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor:
  class TransTabLinearClassifier (line 576) | class TransTabLinearClassifier(nn.Module):
    method __init__ (line 577) | def __init__(self,
    method forward (line 587) | def forward(self, x) -> Tensor:
  class TransTabLinearRegressor (line 593) | class TransTabLinearRegressor(nn.Module):
    method __init__ (line 594) | def __init__(self,
    method forward (line 600) | def forward(self, x) -> Tensor:
  class TransTabProjectionHead (line 606) | class TransTabProjectionHead(nn.Module):
    method __init__ (line 607) | def __init__(self,
    method forward (line 613) | def forward(self, x) -> Tensor:
  class TransTabCLSToken (line 617) | class TransTabCLSToken(nn.Module):
    method __init__ (line 620) | def __init__(self, hidden_dim) -> None:
    method expand (line 626) | def expand(self, *leading_dimensions):
    method forward (line 630) | def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor:
  class TransTabModel (line 638) | class TransTabModel(nn.Module):
    method __init__ (line 683) | def __init__(self,
    method forward (line 744) | def forward(self, x, y=None):
    method load (line 771) | def load(self, ckpt_dir):
    method save (line 799) | def save(self, ckpt_dir):
    method update (line 825) | def update(self, config):
    method _check_column_overlap (line 856) | def _check_column_overlap(self, cat_cols=None, num_cols=None, bin_cols...
    method _solve_duplicate_cols (line 866) | def _solve_duplicate_cols(self, duplicate_cols):
    method _adapt_to_new_num_class (line 879) | def _adapt_to_new_num_class(self, num_class):
  class TransTabClassifier (line 891) | class TransTabClassifier(TransTabModel):
    method __init__ (line 937) | def __init__(self,
    method forward (line 974) | def forward(self, x, y=None):
  class TransTabRegressor (line 1027) | class TransTabRegressor(TransTabModel):
    method __init__ (line 1073) | def __init__(self,
    method forward (line 1108) | def forward(self, x, y=None):
  class TransTabForCL (line 1157) | class TransTabForCL(TransTabModel):
    method __init__ (line 1218) | def __init__(self,
    method forward (line 1265) | def forward(self, x, y=None):
    method _build_positive_pairs (line 1321) | def _build_positive_pairs(self, x, n):
    method cos_sim (line 1336) | def cos_sim(self, a, b):
    method self_supervised_contrastive_loss (line 1353) | def self_supervised_contrastive_loss(self, features):
    method supervised_contrastive_loss (line 1390) | def supervised_contrastive_loss(self, features, labels):

FILE: transtab/trainer.py
  class Trainer (line 24) | class Trainer:
    method __init__ (line 25) | def __init__(self,
    method train (line 101) | def train(self):
    method evaluate (line 144) | def evaluate(self):
    method train_no_dataloader (line 181) | def train_no_dataloader(self,
    method save_model (line 242) | def save_model(self, output_dir=None):
    method create_optimizer (line 264) | def create_optimizer(self):
    method create_scheduler (line 280) | def create_scheduler(self, num_training_steps, optimizer):
    method get_num_train_steps (line 289) | def get_num_train_steps(self, train_set_list, num_epoch, batch_size):
    method get_warmup_steps (line 297) | def get_warmup_steps(self, num_training_steps):
    method _build_dataloader (line 306) | def _build_dataloader(self, trainset, batch_size, collator, num_worker...

FILE: transtab/trainer_utils.py
  class TrainDataset (line 30) | class TrainDataset(Dataset):
    method __init__ (line 31) | def __init__(self, trainset):
    method __len__ (line 34) | def __len__(self):
    method __getitem__ (line 37) | def __getitem__(self, index):
  class TrainCollator (line 45) | class TrainCollator:
    method __init__ (line 48) | def __init__(self,
    method save (line 63) | def save(self, path):
    method __call__ (line 66) | def __call__(self, data):
  class SupervisedTrainCollator (line 69) | class SupervisedTrainCollator(TrainCollator):
    method __init__ (line 70) | def __init__(self,
    method __call__ (line 84) | def __call__(self, data):
  class TransTabCollatorForCL (line 90) | class TransTabCollatorForCL(TrainCollator):
    method __init__ (line 93) | def __init__(self,
    method __call__ (line 113) | def __call__(self, data):
    method _build_positive_pairs (line 132) | def _build_positive_pairs(self, x, n):
    method _build_positive_pairs_single_view (line 150) | def _build_positive_pairs_single_view(self, x):
  function get_parameter_names (line 160) | def get_parameter_names(model, forbidden_layer_types):
  function random_seed (line 175) | def random_seed(seed):
  function get_scheduler (line 181) | def get_scheduler(

FILE: transtab/transtab.py
  function build_classifier (line 14) | def build_classifier(
  function build_regressor (line 98) | def build_regressor(
  function build_extractor (line 182) | def build_extractor(
  function build_encoder (line 236) | def build_encoder(
  function build_contrastive_learner (line 329) | def build_contrastive_learner(
  function train (line 456) | def train(model,
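The `data_cut` option in `transtab/dataset.py` (indexed above as `load_single_data`) shuffles the columns, slices them into equal groups, and pads each group with half a group's worth of columns borrowed from the other groups, so the resulting sub-tables partially overlap. A pure-Python sketch of that column-splitting step, simplified to the case where the columns divide evenly (the original code additionally redistributes any remainder columns):

```python
import random

def split_columns_with_overlap(all_cols, data_cut, seed=123):
    """Mimic the data_cut partitioning: equal column groups plus ~50% borrowed overlap."""
    rng = random.Random(seed)
    cols = list(all_cols)
    rng.shuffle(cols)
    sp_size = len(cols) // data_cut
    # Slice into `data_cut` equal groups.
    groups = [cols[i:i + sp_size] for i in range(0, data_cut * sp_size, sp_size)]
    padded = []
    for g in groups:
        # Borrow half a group's worth of columns from outside this group.
        others = [c for c in cols if c not in g]
        padded.append(g + rng.sample(others, sp_size // 2))
    return padded

groups = split_columns_with_overlap([f"col{i}" for i in range(10)], data_cut=2)
```

Each partition then carries its own rows but shares part of the column vocabulary with the others, which is what makes pretraining across variable-column sub-tables (as in the contrastive-learning example) meaningful.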
Condensed preview — 49 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1848,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".readthedocs.yaml",
    "chars": 1033,
    "preview": "# Read the Docs configuration file for Sphinx projects\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html f"
  },
  {
    "path": "LICENSE",
    "chars": 1316,
    "preview": "BSD 2-Clause License\n\nCopyright (c) 2022, Zifeng\nAll rights reserved.\n\nRedistribution and use in source and binary forms"
  },
  {
    "path": "README.md",
    "chars": 4777,
    "preview": "# TransTab: A flexible transferable tabular learning framework [[arxiv]](https://arxiv.org/pdf/2205.09328.pdf)\n\n\n[![PyPI"
  },
  {
    "path": "blog/README.md",
    "chars": 13589,
    "preview": "# NeurIPS'22 | How to perform transfer learning and zero-shot learning on tabular data?\n\n> This is our paper accepted by"
  },
  {
    "path": "docs/Makefile",
    "chars": 638,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
  },
  {
    "path": "docs/make.bat",
    "chars": 769,
    "preview": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-bu"
  },
  {
    "path": "docs/requirements.txt",
    "chars": 197,
    "preview": "sphinx-markdown-tables\nrecommonmark\nsphinx==4.2.0\nsphinx_rtd_theme==1.0.0\nreadthedocs-sphinx-search==0.1.1\nloguru\nnumpy\n"
  },
  {
    "path": "docs/source/about.rst",
    "chars": 176,
    "preview": "About Us\n========\n\nThis package was developed and maintained by Zifeng Wang (Ph.D. Student @ UIUC).\n\nPlease refer to his"
  },
  {
    "path": "docs/source/conf.py",
    "chars": 2268,
    "preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
  },
  {
    "path": "docs/source/data_preparation.rst",
    "chars": 2248,
    "preview": "Custom Dataset\n==============\n\nHere is the best practice to build your own datasets for `transtab`.\n\n::\n\n    project\n   "
  },
  {
    "path": "docs/source/example_encode.rst",
    "chars": 1457,
    "preview": "Encode Tables\n=============\n\n*transtab* is able to take pd.DataFrame as inputs and outputs the encoded sample-level embe"
  },
  {
    "path": "docs/source/example_pretrain.rst",
    "chars": 2264,
    "preview": "Tabular Pretraining\n===================\n\nWhen encountering multiple distinct tables which may have different number of c"
  },
  {
    "path": "docs/source/example_transfer.rst",
    "chars": 1941,
    "preview": "Tabular Transfer Learning\n=========================\n\n*transtab* is able to leverage the knowledge learned from broad dat"
  },
  {
    "path": "docs/source/fast_train.rst",
    "chars": 1888,
    "preview": "Fast Train with TransTab\n=========================\n\n*transtab* is featured for accepting variable-column tables for trai"
  },
  {
    "path": "docs/source/index.rst",
    "chars": 3326,
    "preview": "Welcome to transtab documentation!\n==================================\n\n`transtab` is an easy-to-use **Python package** f"
  },
  {
    "path": "docs/source/install.rst",
    "chars": 950,
    "preview": "Installation\n============\n\n*transtab* was tested on Python 3.7+, PyTorch 1.8.0+. Please follow the Installation instruct"
  },
  {
    "path": "docs/source/main_func.rst",
    "chars": 335,
    "preview": "Main Functions\n==============\n\n.. toctree::\n    load_data<transtab.load_data>\n    build_classifier<transtab.build_classi"
  },
  {
    "path": "docs/source/models.rst",
    "chars": 145,
    "preview": "Models\n======\n\n.. toctree::\n    BaseModel<transtab.basemodel>\n    TransTabClassifier<transtab.classifier>\n    TransTabFo"
  },
  {
    "path": "docs/source/transtab.basemodel.rst",
    "chars": 148,
    "preview": "TransTabModel\n=============\n\n.. automodule:: transtab.modeling_transtab\n    :members: TransTabModel\n    :no-undoc-member"
  },
  {
    "path": "docs/source/transtab.build_classifier.rst",
    "chars": 306,
    "preview": "build_classifier\n================\n\n.. autofunction:: transtab.build_classifier\n\n.. warning::\n    If ``categorical_column"
  },
  {
    "path": "docs/source/transtab.build_contrastive_learner.rst",
    "chars": 106,
    "preview": "build_contrastive_learner\n=========================\n\n.. autofunction:: transtab.build_contrastive_learner\n"
  },
  {
    "path": "docs/source/transtab.build_encoder.rst",
    "chars": 548,
    "preview": "build_extractor\n===============\n\n.. autofunction:: transtab.build_encoder\n\nThe returned feature extractor takes pd.DataF"
  },
  {
    "path": "docs/source/transtab.build_extractor.rst",
    "chars": 839,
    "preview": "build_extractor\n===============\n\n.. autofunction:: transtab.build_extractor\n\n\nThe returned feature extractor takes pd.Da"
  },
  {
    "path": "docs/source/transtab.classifier.rst",
    "chars": 163,
    "preview": "TransTabClassifier\n==================\n\n.. autoclass:: transtab.modeling_transtab.TransTabClassifier\n    :members:\n    :n"
  },
  {
    "path": "docs/source/transtab.contrastive.rst",
    "chars": 148,
    "preview": "TransTabForCL\n=============\n\n.. autoclass:: transtab.modeling_transtab.TransTabForCL\n    :members:\n    :no-undoc-members"
  },
  {
    "path": "docs/source/transtab.load_data.rst",
    "chars": 2019,
    "preview": "load_data\n=========\n\n.. autofunction:: transtab.load_data\n\n\n*transtab* provides flexible data loading function.\nIt can b"
  },
  {
    "path": "docs/source/transtab.predict.rst",
    "chars": 52,
    "preview": "predict\n=======\n\n.. autofunction:: transtab.predict\n"
  },
  {
    "path": "docs/source/transtab.train.rst",
    "chars": 46,
    "preview": "train\n=====\n\n.. autofunction:: transtab.train\n"
  },
  {
    "path": "docs/sphinx-commands.txt",
    "chars": 52,
    "preview": "# build html files\nsphinx-build -b html source build"
  },
  {
    "path": "examples/contrastive_learning.ipynb",
    "chars": 13754,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"0c0001bb\",\n   \"metadata\": {},\n   \"outputs\":"
  },
  {
    "path": "examples/fast_train.ipynb",
    "chars": 40581,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"0bc8ef17\",\n   \"metadata\": {},\n   \"outputs\":"
  },
  {
    "path": "examples/table_embedding.ipynb",
    "chars": 26695,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"9aa34ef4\",\n   \"metadata\": {},\n   \"outputs\":"
  },
  {
    "path": "examples/transfer_learning.ipynb",
    "chars": 30431,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"134f979d\",\n   \"metadata\": {},\n   \"outputs\":"
  },
  {
    "path": "examples/transfer_learning_regressor.ipynb",
    "chars": 22717,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"739e0cff\",\n   \"metadata\": {},\n   \"outputs\":"
  },
  {
    "path": "pypi_build_commands.txt",
    "chars": 341,
    "preview": "# This is a command list for building pypi packages\npython setup.py sdist bdist_wheel\n\ntwine check dist/*\n\n# upload to p"
  },
  {
    "path": "requirements.txt",
    "chars": 85,
    "preview": "numpy\nscikit_learn\nsetuptools\ntransformers<=4.30.0\ntqdm\npandas>=1.3.0\nopenml>=0.10.0\n"
  },
  {
    "path": "setup.py",
    "chars": 1202,
    "preview": "import os\nimport setuptools\n\nthis_directory = os.path.abspath(os.path.dirname(__file__))\n\nwith open(\"README.md\", \"r\") as"
  },
  {
    "path": "transtab/__init__.py",
    "chars": 61,
    "preview": "name = 'transtab'\nversion = '0.0.6'\n\nfrom .transtab import *\n"
  },
  {
    "path": "transtab/constants.py",
    "chars": 368,
    "preview": "# Name of the files used for checkpointing\nTRAINING_ARGS_NAME = \"training_args.json\"\nTRAINER_STATE_NAME = \"trainer_state"
  },
  {
    "path": "transtab/dataset.py",
    "chars": 12562,
    "preview": "import os\nimport pdb\n\nimport pandas as pd\nimport numpy as np\nfrom sklearn.preprocessing import LabelEncoder, OrdinalEnco"
  },
  {
    "path": "transtab/evaluator.py",
    "chars": 5965,
    "preview": "from collections import defaultdict\nimport os\nimport pdb\n\nimport torch\nimport numpy as np\nfrom sklearn.metrics import ro"
  },
  {
    "path": "transtab/modeling_transtab.py",
    "chars": 54865,
    "preview": "import os, pdb\nimport math\nimport collections\nimport json\nfrom typing import Dict, Optional, Any, Union, Callable, List\n"
  },
  {
    "path": "transtab/tokenizer/special_tokens_map.json",
    "chars": 112,
    "preview": "{\"unk_token\": \"[UNK]\", \"sep_token\": \"[SEP]\", \"pad_token\": \"[PAD]\", \"cls_token\": \"[CLS]\", \"mask_token\": \"[MASK]\"}"
  },
  {
    "path": "transtab/tokenizer/tokenizer_config.json",
    "chars": 48,
    "preview": "{\"do_lower_case\": true, \"model_max_length\": 512}"
  },
  {
    "path": "transtab/tokenizer/vocab.txt",
    "chars": 228209,
    "preview": "[PAD]\n[unused0]\n[unused1]\n[unused2]\n[unused3]\n[unused4]\n[unused5]\n[unused6]\n[unused7]\n[unused8]\n[unused9]\n[unused10]\n[un"
  },
  {
    "path": "transtab/trainer.py",
    "chars": 13845,
    "preview": "import os\nimport pdb\nimport math\nimport time\nimport json\n\nimport torch\nfrom torch import nn\nfrom torch.utils.data import"
  },
  {
    "path": "transtab/trainer_utils.py",
    "chars": 7772,
    "preview": "import pdb\nimport os\nimport random\nimport math\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.utils.dat"
  },
  {
    "path": "transtab/transtab.py",
    "chars": 17684,
    "preview": "import pdb\nimport os\n\nfrom transtab import constants\nfrom transtab.modeling_transtab import TransTabClassifier, TransTab"
  }
]

About this extraction

This page contains the source code of the RyanWangZf/transtab GitHub repository, extracted and formatted as plain text for AI agents and large language models: 49 files (510.6 KB), approximately 177.7k tokens, and a symbol index with 115 extracted functions, classes, methods, constants, and types. Extracted by GitExtract, built by Nikandr Surkov.
