Repository: nyoki-mtl/bert-mcts-youtube
Branch: main
Commit: a12e0bdaf313
Files: 41
Total size: 52.3 KB
Directory structure:
gitextract_4mn9ams8/
├── .gitignore
├── Makefile
├── README.md
├── configs/
│ ├── .gitkeep
│ ├── mlm_base.yaml
│ └── policy_value_base.yaml
├── docker/
│ ├── Dockerfile
│ └── entrypoint.sh
├── engine/
│ ├── mcts_player.sh
│ └── policy_player.sh
├── env_name.yml
├── requirements.txt
├── setup.py
├── src/
│ ├── data/
│ │ ├── __init__.py
│ │ ├── mlm.py
│ │ └── policy_value.py
│ ├── features/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ └── policy_value.py
│ ├── model/
│ │ ├── __init__.py
│ │ └── bert.py
│ ├── pl_modules/
│ │ ├── __init__.py
│ │ ├── mlm.py
│ │ └── policy_value.py
│ ├── player/
│ │ ├── __init__.py
│ │ ├── base_player.py
│ │ ├── mcts_player.py
│ │ ├── policy_player.py
│ │ └── usi.py
│ ├── uct/
│ │ ├── __init__.py
│ │ └── uct_node.py
│ └── utils/
│ ├── __init__.py
│ ├── hcpe.py
│ ├── misc.py
│ ├── sfen.py
│ └── shogi.py
└── tools/
├── download_and_build_lesserkai.sh
├── make_dataset.py
├── pl_to_transformers.py
├── test_engine.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Project specific folders
lightning_logs/
tf_logs/
remote_logs/
local_logs/
*_logs/
checkpoints/
dataset/
local_misc/
notebooks/data/
# LSF logfiles
lsf.*
# IDEA folders
.idea/
/work_dirs/*
/data/*
!.gitkeep
================================================
FILE: Makefile
================================================
PROJECT ?= bert-mcts
DATADIR ?= ${PWD}/data
WORKSPACE ?= /workspace/$(PROJECT)
DOCKER_IMAGE ?= ${PROJECT}:latest
SHMSIZE ?= 100G
LOG_DIR ?= ${PWD}/work_dirs/logs
DOCKER_OPTS := \
--name ${PROJECT} \
--rm -it \
--shm-size=${SHMSIZE} \
-v ${PWD}:${WORKSPACE} \
-v ${DATADIR}:${WORKSPACE}/data \
-v ${LOG_DIR}:${WORKSPACE}/work_dirs/logs \
-w ${WORKSPACE} \
--ipc=host \
--network=host \
--gpus all
docker-build:
docker build -f docker/Dockerfile -t ${DOCKER_IMAGE} .
docker-start-interactive: docker-build
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash
docker-start-jupyter: docker-build
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \
bash -c "jupyter lab --port=8888 --ip=0.0.0.0 --allow-root --no-browser"
docker-run: docker-build
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \
bash -c "${COMMAND}"
================================================
FILE: README.md
================================================
# BERT-MCTS-YOUTUBE
This is the shogi engine that played against Yobinori Takumi on YouTube.
It combines BERT, a natural language model, with Monte Carlo tree search (MCTS).
Since everything is written in Python, the search is slow.
Most of the code outside the BERT parts is based on the book 『将棋AIで学ぶディープラーニング』 ("Deep Learning with Shogi AI").
- [Book (Amazon)](https://www.amazon.co.jp/dp/B07B7JJ929)
- [GitHub](https://github.com/TadaoYamaoka/python-dlshogi)
## Environment
### Colab
If you just want to try it out, [Google Colab](https://colab.research.google.com/drive/10KAuLlNe6FKZBp_iE2bQJPNhoY2WeACx?usp=sharing) is the easiest option.
The rest of this section covers running locally. CPU-only inference is slow, so a CUDA environment is recommended.
Weight files are uploaded [here](https://drive.google.com/drive/folders/1N-Np2NmNLtLGS9gjnreBkYdTxrH1EHFw?usp=sharing):
youtube_version.ckpt is the checkpoint that played against Takumi, and latest.ckpt was trained for a few additional days.
Set the path to the downloaded file inside engine/***_player.sh; by default the scripts expect it under work_dirs.
### Docker
If you have nvidia-docker with CUDA 10.2 or later set up, enter the environment with:
```bash
$ make docker-start-interactive
```
### Ubuntu 18.04
With CUDA 10.2 and Anaconda installed, create the virtual environment with:
```bash
$ conda env create -f env_name.yml
$ conda activate bert-mcts-youtube
$ python setup.py develop
```
### Windows 10
Untested.
## Testing the shogi engines
The engines are provided in the engine directory and can be launched from a GUI such as ShogiGUI.
- policy_player.sh plays moves using only the policy that BERT outputs (weak)
- mcts_player.sh searches with MCTS guided by BERT's output
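To run a quick command-line match against the bundled benchmark engine instead of using a GUI, you can build Lesserkai and pit an engine against it with the provided tools (this assumes a checkpoint has already been placed under work_dirs as described above):
```bash
$ bash tools/download_and_build_lesserkai.sh
$ python tools/test_engine.py engine/mcts_player.sh
```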
## Training
Training uses the even-position collection (互角局面集) and GCT's self-play game records.
The model is first pretrained as a Masked Language Model, then trained as a Policy Value Network.
That said, shogi has plenty of high-quality training data, so the pretraining probably adds little.
### Preparing the data
Download the even-position collection:
```bash
$ cd data
$ git clone https://github.com/tttak/ShogiGokakuKyokumen.git
```
GCT's self-play game records:
```bash
$ cd data
$ mkdir hcpe
```
Kano-san, the developer of GCT, has kindly published the self-play records at [this link](https://drive.google.com/drive/folders/14FaqqIHRctTQIY6hScCFXWQQZ_pSU3-F).
Download a few of the selfplay-*** files from there into data/hcpe.
The files are large, so even a single one contains plenty of data.
Once these are in place, build the dataset with:
```bash
$ python tools/make_dataset.py
```
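This writes train.npy and val.npy under data/dataset/gokaku_040, data/dataset/gokaku_100, and data/dataset/selfplay (see tools/make_dataset.py).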
### Training the Masked Language Model
```bash
$ python tools/train.py configs/mlm_base.yaml
```
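Checkpoints and TensorBoard logs are written under work_dirs/mlm_base/version_0/ (tools/train.py names the run after the config file's stem), which is the path the conversion step below expects.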
### Converting the weights
Convert the Masked Language Model checkpoint to the transformers format.
This makes the transfer-learning code a little easier to write.
```bash
$ python tools/pl_to_transformers.py work_dirs/mlm_base/version_0/checkpoints/last.ckpt
```
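The converted directory can then be loaded via transformers' from_pretrained, which is how the Policy Value step consumes it (configs/policy_value_base.yaml points model_dir at the same checkpoints directory). A minimal sketch, assuming the conversion above has been run:
```python
# Minimal sketch: build the policy/value model on top of the converted
# MLM weights. BertPolicyValue calls BertModel.from_pretrained(model_dir),
# which reads the pytorch_model.bin and config.json written above.
from src.model.bert import BertPolicyValue

model = BertPolicyValue(model_dir='./work_dirs/mlm_base/version_0/checkpoints')
```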
### Training the Policy Value Model
Finally, train the Policy Value model on top of the converted weights:
```bash
$ python tools/train.py configs/policy_value_base.yaml
```
================================================
FILE: configs/.gitkeep
================================================
================================================
FILE: configs/mlm_base.yaml
================================================
model_type: 'MLM'
seed: 42
dataset_dir: './data/dataset/gokaku_100'
model_dir:
train_loader:
batch_size: 64
shuffle: True
num_workers: 8
pin_memory: False
drop_last: True
val_loader:
batch_size: 64
shuffle: False
num_workers: 8
pin_memory: False
drop_last: False
train_params:
max_epochs: 5
# interval (in steps) between validation runs and checkpoints
val_check_interval: 3000
# adjust to your environment as needed
gpus: [0]
================================================
FILE: configs/policy_value_base.yaml
================================================
model_type: 'PolicyValue'
seed: 42
dataset_dir: './data/dataset/selfplay'
model_dir: './work_dirs/mlm_base/version_0/checkpoints'
train_loader:
batch_size: 128
shuffle: True
num_workers: 0
pin_memory: True
drop_last: True
val_loader:
batch_size: 128
shuffle: False
num_workers: 0
pin_memory: True
drop_last: False
train_params:
max_epochs: 1
# interval (in steps) between validation runs and checkpoints
val_check_interval: 30000
limit_val_batches: 0.1
# adjust to your environment as needed
gpus: [0]
================================================
FILE: docker/Dockerfile
================================================
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
ENV DEBIAN_FRONTEND=noninteractive
ADD requirements.txt /tmp
RUN pip install -r /tmp/requirements.txt
ADD docker/entrypoint.sh /tmp
ENTRYPOINT ["bash", "/tmp/entrypoint.sh"]
================================================
FILE: docker/entrypoint.sh
================================================
python setup.py develop
exec "$@"
================================================
FILE: engine/mcts_player.sh
================================================
#!/bin/sh
python -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt
================================================
FILE: engine/policy_player.sh
================================================
#!/bin/sh
python -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt
================================================
FILE: env_name.yml
================================================
name: bert-mcts-youtube
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- alabaster=0.7.12=py37_0
- anaconda=2020.11=py37_0
- anaconda-client=1.7.2=py37_0
- anaconda-project=0.8.4=py_0
- argh=0.26.2=py37_0
- argon2-cffi=20.1.0=py37h7b6447c_1
- asn1crypto=1.4.0=py_0
- astroid=2.4.2=py37_0
- astropy=4.0.2=py37h7b6447c_0
- async_generator=1.10=py37h28b3542_0
- atomicwrites=1.4.0=py_0
- attrs=20.3.0=pyhd3eb1b0_0
- autopep8=1.5.4=py_0
- babel=2.8.1=pyhd3eb1b0_0
- backcall=0.2.0=py_0
- backports=1.0=py_2
- backports.shutil_get_terminal_size=1.0.0=py37_2
- beautifulsoup4=4.9.3=pyhb0f4dca_0
- bitarray=1.6.1=py37h27cfd23_0
- bkcharts=0.2=py37_0
- blas=1.0=mkl
- bleach=3.2.1=py_0
- blosc=1.20.1=hd408876_0
- bokeh=2.2.3=py37_0
- boto=2.49.0=py37_0
- bottleneck=1.3.2=py37heb32a55_1
- brotlipy=0.7.0=py37h7b6447c_1000
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2020.10.14=0
- cairo=1.14.12=h8948797_3
- certifi=2020.6.20=pyhd3eb1b0_3
- cffi=1.14.3=py37he30daa8_0
- chardet=3.0.4=py37_1003
- click=7.1.2=py_0
- cloudpickle=1.6.0=py_0
- clyent=1.2.2=py37_1
- colorama=0.4.4=py_0
- contextlib2=0.6.0.post1=py_0
- cryptography=3.1.1=py37h1ba5d50_0
- cudatoolkit=10.2.89=hfd86e86_1
- curl=7.71.1=hbc83047_1
- cycler=0.10.0=py37_0
- cython=0.29.21=py37he6710b0_0
- cytoolz=0.11.0=py37h7b6447c_0
- dask=2.30.0=py_0
- dask-core=2.30.0=py_0
- dbus=1.13.18=hb2f20db_0
- decorator=4.4.2=py_0
- defusedxml=0.6.0=py_0
- diff-match-patch=20200713=py_0
- distributed=2.30.1=py37h06a4308_0
- docutils=0.16=py37_1
- entrypoints=0.3=py37_0
- et_xmlfile=1.0.1=py_1001
- expat=2.2.10=he6710b0_2
- fastcache=1.1.0=py37h7b6447c_0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.0.12=py_0
- flake8=3.8.4=py_0
- flask=1.1.2=py_0
- fontconfig=2.13.0=h9420a91_0
- freetype=2.10.4=h5ab3b9f_0
- fribidi=1.0.10=h7b6447c_0
- fsspec=0.8.3=py_0
- future=0.18.2=py37_1
- get_terminal_size=1.0.0=haa9412d_0
- gevent=20.9.0=py37h7b6447c_0
- glib=2.66.1=h92f7085_0
- glob2=0.7=py_0
- gmp=6.1.2=h6c8ec71_1
- gmpy2=2.0.8=py37h10f8cd9_2
- gnutls=3.6.15=he1e5248_0
- graphite2=1.3.14=h23475e2_0
- greenlet=0.4.17=py37h7b6447c_0
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb31296c_0
- h5py=2.10.0=py37h7918eee_0
- harfbuzz=2.4.0=hca77d97_1
- hdf5=1.10.4=hb1b8bf9_0
- heapdict=1.0.1=py_0
- html5lib=1.1=py_0
- icu=58.2=he6710b0_3
- idna=2.10=py_0
- imageio=2.9.0=py_0
- imagesize=1.2.0=py_0
- importlib-metadata=2.0.0=py_1
- importlib_metadata=2.0.0=1
- iniconfig=1.1.1=py_0
- intel-openmp=2020.2=254
- intervaltree=3.1.0=py_0
- ipykernel=5.3.4=py37h5ca1d4c_0
- ipython=7.19.0=py37hb070fc8_0
- ipython_genutils=0.2.0=py37_0
- ipywidgets=7.5.1=py_1
- isort=5.6.4=py_0
- itsdangerous=1.1.0=py37_0
- jbig=2.1=hdba287a_0
- jdcal=1.4.1=py_0
- jedi=0.17.1=py37_0
- jeepney=0.5.0=pyhd3eb1b0_0
- jinja2=2.11.2=py_0
- joblib=0.17.0=py_0
- jpeg=9b=h024ee3a_2
- json5=0.9.5=py_0
- jsonschema=3.2.0=py_2
- jupyter=1.0.0=py37_7
- jupyter_client=6.1.7=py_0
- jupyter_console=6.2.0=py_0
- jupyter_core=4.6.3=py37_0
- jupyterlab=2.2.6=py_0
- jupyterlab_pygments=0.1.2=py_0
- jupyterlab_server=1.2.0=py_0
- keyring=21.4.0=py37_1
- kiwisolver=1.3.0=py37h2531618_0
- krb5=1.18.2=h173b8e3_0
- lame=3.100=h7b6447c_0
- lazy-object-proxy=1.4.3=py37h7b6447c_0
- lcms2=2.11=h396b838_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libarchive=3.4.2=h62408e4_0
- libcurl=7.71.1=h20c2e04_1
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.0=h27cfd23_0
- liblief=0.10.1=he6710b0_0
- libllvm10=10.0.1=hbcb73fb_5
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libspatialindex=1.9.3=he6710b0_0
- libssh2=1.9.0=h1ba5d50_1
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.1.0=h2733197_1
- libtool=2.4.6=h7b6447c_1005
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h1bed415_2
- libuv=1.40.0=h7b6447c_0
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.10=hb55368b_3
- libxslt=1.1.34=hc22bd24_0
- llvmlite=0.34.0=py37h269e1b5_4
- locket=0.2.0=py37_1
- lxml=4.6.1=py37hefd8a0e_0
- lz4-c=1.9.2=heb0550a_3
- lzo=2.10=h7b6447c_2
- markupsafe=1.1.1=py37h14c3975_1
- matplotlib=3.3.2=0
- matplotlib-base=3.3.2=py37h817c723_0
- mccabe=0.6.1=py37_1
- mistune=0.8.4=py37h14c3975_1001
- mkl=2020.2=256
- mkl-service=2.3.0=py37he904b0f_0
- mkl_fft=1.2.0=py37h23d657b_0
- mkl_random=1.1.1=py37h0573a6f_0
- mock=4.0.2=py_0
- more-itertools=8.6.0=pyhd3eb1b0_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.1.0=py37_0
- msgpack-python=1.0.0=py37hfd86e86_1
- multipledispatch=0.6.0=py37_0
- nbclient=0.5.1=py_0
- nbconvert=6.0.7=py37_0
- nbformat=5.0.8=py_0
- ncurses=6.2=he6710b0_1
- nest-asyncio=1.4.2=pyhd3eb1b0_0
- nettle=3.7.2=hbbd107a_1
- networkx=2.5=py_0
- ninja=1.10.2=py37hff7bd54_0
- nltk=3.5=py_0
- nose=1.3.7=py37_1004
- notebook=6.1.4=py37_0
- numba=0.51.2=py37h04863e7_1
- numexpr=2.7.1=py37h423224d_0
- numpy=1.19.2=py37h54aff64_0
- numpy-base=1.19.2=py37hfa32c7d_0
- numpydoc=1.1.0=pyhd3eb1b0_1
- olefile=0.46=py37_0
- openh264=2.1.0=hd408876_0
- openpyxl=3.0.5=py_0
- openssl=1.1.1h=h7b6447c_0
- packaging=20.4=py_0
- pandas=1.1.3=py37he6710b0_0
- pandoc=2.11=hb0f4dca_0
- pandocfilters=1.4.3=py37h06a4308_1
- pango=1.45.3=hd140c19_0
- parso=0.7.0=py_0
- partd=1.1.0=py_0
- patchelf=0.12=he6710b0_0
- path=15.0.0=py37_0
- path.py=12.5.0=0
- pathlib2=2.3.5=py37_1
- pathtools=0.1.2=py_1
- patsy=0.5.1=py37_0
- pcre=8.44=he6710b0_0
- pep8=1.7.1=py37_0
- pexpect=4.8.0=py37_1
- pickleshare=0.7.5=py37_1001
- pillow=8.0.1=py37he98fc37_0
- pip=20.2.4=py37h06a4308_0
- pixman=0.40.0=h7b6447c_0
- pkginfo=1.6.1=py37h06a4308_0
- pluggy=0.13.1=py37_0
- ply=3.11=py37_0
- prometheus_client=0.8.0=py_0
- prompt-toolkit=3.0.8=py_0
- prompt_toolkit=3.0.8=0
- psutil=5.7.2=py37h7b6447c_0
- ptyprocess=0.6.0=py37_0
- py=1.9.0=py_0
- py-lief=0.10.1=py37h403a769_0
- pycodestyle=2.6.0=py_0
- pycosat=0.6.3=py37h7b6447c_0
- pycparser=2.20=py_2
- pycrypto=2.6.1=py37h7b6447c_10
- pycurl=7.43.0.6=py37h1ba5d50_0
- pydocstyle=5.1.1=py_0
- pyflakes=2.2.0=py_0
- pygments=2.7.2=pyhd3eb1b0_0
- pylint=2.6.0=py37_0
- pyodbc=4.0.30=py37he6710b0_0
- pyopenssl=19.1.0=py_1
- pyparsing=2.4.7=py_0
- pyqt=5.9.2=py37h05f1152_2
- pyrsistent=0.17.3=py37h7b6447c_0
- pysocks=1.7.1=py37_1
- pytables=3.6.1=py37h71ec239_0
- pytest=6.1.1=py37_0
- python=3.7.9=h7579374_0
- python-dateutil=2.8.1=py_0
- python-jsonrpc-server=0.4.0=py_0
- python-language-server=0.35.1=py_0
- python-libarchive-c=2.9=py_0
- pytorch=1.8.1=py3.7_cuda10.2_cudnn7.6.5_0
- pytz=2020.1=py_0
- pywavelets=1.1.1=py37h7b6447c_2
- pyxdg=0.27=pyhd3eb1b0_0
- pyyaml=5.3.1=py37h7b6447c_1
- pyzmq=19.0.2=py37he6710b0_1
- qdarkstyle=2.8.1=py_0
- qt=5.9.7=h5867ecd_1
- qtawesome=1.0.1=py_0
- qtconsole=4.7.7=py_0
- qtpy=1.9.0=py_0
- readline=8.0=h7b6447c_0
- regex=2020.10.15=py37h7b6447c_0
- requests=2.24.0=py_0
- ripgrep=12.1.1=0
- rope=0.18.0=py_0
- rtree=0.9.4=py37_1
- ruamel_yaml=0.15.87=py37h7b6447c_1
- scikit-image=0.17.2=py37hdf5156a_0
- scikit-learn=0.23.2=py37h0573a6f_0
- scipy=1.5.2=py37h0b6359f_0
- seaborn=0.11.0=py_0
- secretstorage=3.1.2=py37_1
- send2trash=1.5.0=py37_0
- setuptools=50.3.1=py37h06a4308_1
- simplegeneric=0.8.1=py37_2
- singledispatch=3.4.0.3=py_1001
- sip=4.19.8=py37hf484d3e_0
- six=1.15.0=py37h06a4308_0
- snowballstemmer=2.0.0=py_0
- sortedcollections=1.2.1=py_0
- sortedcontainers=2.2.2=py_0
- soupsieve=2.0.1=py_0
- sphinx=3.2.1=py_0
- sphinxcontrib=1.0=py37_1
- sphinxcontrib-applehelp=1.0.2=py_0
- sphinxcontrib-devhelp=1.0.2=py_0
- sphinxcontrib-htmlhelp=1.0.3=py_0
- sphinxcontrib-jsmath=1.0.1=py_0
- sphinxcontrib-qthelp=1.0.3=py_0
- sphinxcontrib-serializinghtml=1.1.4=py_0
- sphinxcontrib-websupport=1.2.4=py_0
- spyder=4.1.5=py37_0
- spyder-kernels=1.9.4=py37_0
- sqlalchemy=1.3.20=py37h27cfd23_0
- sqlite=3.33.0=h62c20be_0
- statsmodels=0.12.0=py37h7b6447c_0
- sympy=1.6.2=py37h06a4308_1
- tbb=2020.3=hfd86e86_0
- tblib=1.7.0=py_0
- terminado=0.9.1=py37_0
- testpath=0.4.4=py_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tifffile=2020.10.1=py37hdd07704_2
- tk=8.6.10=hbc83047_0
- toml=0.10.1=py_0
- toolz=0.11.1=py_0
- torchvision=0.9.1=py37_cu102
- tornado=6.0.4=py37h7b6447c_1
- tqdm=4.50.2=py_0
- traitlets=5.0.5=py_0
- typed-ast=1.4.1=py37h7b6447c_0
- typing_extensions=3.7.4.3=py_0
- ujson=4.0.1=py37he6710b0_0
- unicodecsv=0.14.1=py37_0
- unixodbc=2.3.9=h7b6447c_0
- urllib3=1.25.11=py_0
- watchdog=0.10.3=py37_0
- wcwidth=0.2.5=py_0
- webencodings=0.5.1=py37_1
- werkzeug=1.0.1=py_0
- wheel=0.35.1=py_0
- widgetsnbextension=3.5.1=py37_0
- wrapt=1.11.2=py37h7b6447c_0
- wurlitzer=2.0.1=py37_0
- xlrd=1.2.0=py37_0
- xlsxwriter=1.3.7=py_0
- xlwt=1.3.0=py37_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- yapf=0.30.0=py_0
- zeromq=4.3.3=he6710b0_3
- zict=2.0.0=py_0
- zipp=3.4.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zope=1.0=py37_1
- zope.event=4.5.0=py37_0
- zope.interface=5.1.2=py37h7b6447c_0
- zstd=1.4.5=h9ceee32_0
- pip:
- absl-py==0.12.0
- aiohttp==3.7.4.post0
- async-timeout==3.0.1
- cachetools==4.2.1
- cshogi==0.2.4
- google-auth==1.28.1
- google-auth-oauthlib==0.4.4
- grpcio==1.37.0
- markdown==3.3.4
- multidict==5.1.0
- oauthlib==3.1.0
- omegaconf==2.0.6
- protobuf==3.15.8
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pytorch-lightning==1.2.7
- requests-oauthlib==1.3.0
- rsa==4.7.2
- sacremoses==0.0.44
- tensorboard==2.4.1
- tensorboard-plugin-wit==1.8.0
- tokenizers==0.10.2
- torchmetrics==0.2.0
- transformers==4.5.1
- yarl==1.6.3
prefix: /home/charmer/.pyenv/versions/anaconda3-2020.07/envs/bert-mcts
================================================
FILE: requirements.txt
================================================
--find-links https://download.pytorch.org/whl/torch_stable.html
torch>=1.6
pytorch-lightning==1.2.7
transformers>=4.5
cshogi
omegaconf
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension
if __name__ == '__main__':
setup(
name='bert-mcts',
description='BERT-driven Shogi AI',
author='Hiroki Taniai',
author_email='charmer.popopo@gmail.com',
packages=find_packages(include=['src', 'src.*']),
cmdclass={'build_ext': BuildExtension},
)
================================================
FILE: src/data/__init__.py
================================================
================================================
FILE: src/data/mlm.py
================================================
import numpy as np
import torch
from torch.utils.data import Dataset
from src.utils.shogi import pieces_list
# Masked Language Model
class MLMDataset(Dataset):
def __init__(self, data):
self.seqs = data['seq']
self.mask_token_id = 32 # 32 is an id not assigned to any piece
assert self.mask_token_id not in np.array(pieces_list)
def __len__(self):
return len(self.seqs)
def __getitem__(self, idx):
inputs = np.array(self.seqs[idx])
labels = inputs.copy()
# positions selected for prediction
masked_indices = np.random.random(labels.shape) < 0.15
labels[~masked_indices] = -100
# 80% of the selected positions become the mask token
indices_replaced = (np.random.random(labels.shape) < 0.8) & masked_indices
inputs[indices_replaced] = self.mask_token_id
# 10% are replaced with a random piece (half of the remaining 20%)
indices_random = (np.random.random(labels.shape) < 0.5) & masked_indices & ~indices_replaced
random_words = np.random.choice(pieces_list, labels.shape)
inputs[indices_random] = random_words[indices_random]
# the remaining 10% are left unchanged
ret_dict = {'input_ids': torch.tensor(inputs, dtype=torch.long),
'labels': torch.tensor(labels, dtype=torch.long)}
return ret_dict
================================================
FILE: src/data/policy_value.py
================================================
import torch
from torch.utils.data import Dataset
class PolicyValueDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
dt = self.data[idx]
ret_dict = {'input_ids': torch.tensor(dt['seq'], dtype=torch.long),
'labels': torch.tensor(dt['label'], dtype=torch.long),
'values': torch.tensor(dt['value'], dtype=torch.float),
'result': torch.tensor(dt['result'], dtype=torch.float)} # result is 0, 0.5 or 1; torch.long would truncate the draws
return ret_dict
================================================
FILE: src/features/__init__.py
================================================
================================================
FILE: src/features/common.py
================================================
import cshogi
from src.utils.shogi import reverse_piece_fn
def get_seq_from_board(board):
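# returns a 95-token sequence: 81 board squares, then the side to move's 7 hand-piece counts, then the opponent's 7 (everything viewed from the side to move)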
bp, wp = board.pieces_in_hand
if board.turn == cshogi.BLACK:
seq = board.pieces + bp + wp
else:
# reverse the square order and swap black and white pieces; swap the hands as well
seq = reverse_piece_fn(board.pieces[::-1]).tolist() + wp + bp
return seq
================================================
FILE: src/features/policy_value.py
================================================
import cshogi
import numpy as np
from cshogi import move_drop_hand_piece, move_from, move_is_drop, move_is_promotion, move_to
from src.features.common import get_seq_from_board
from src.utils.shogi import (DOWN, DOWN_LEFT, DOWN_RIGHT, LEFT, MOVE_DIRECTION, MOVE_DIRECTION_PROMOTED, RIGHT, UP,
UP2_LEFT, UP2_RIGHT, UP_LEFT, UP_RIGHT)
board = cshogi.Board()
def get_move_label(move, color):
if not move_is_drop(move):
from_sq = move_from(move)
to_sq = move_to(move)
if color == cshogi.WHITE:
to_sq = 80 - to_sq
from_sq = 80 - from_sq
# file: 筋 (column), rank: 段 (row)
from_file, from_rank = divmod(from_sq, 9)
to_file, to_rank = divmod(to_sq, 9)
dir_file = to_file - from_file
dir_rank = to_rank - from_rank
if dir_rank < 0 and dir_file == 0:
move_direction = UP
elif dir_rank == -2 and dir_file == -1:
move_direction = UP2_RIGHT
elif dir_rank == -2 and dir_file == 1:
move_direction = UP2_LEFT
elif dir_rank < 0 and dir_file < 0:
move_direction = UP_RIGHT
elif dir_rank < 0 and dir_file > 0:
move_direction = UP_LEFT
elif dir_rank == 0 and dir_file < 0:
move_direction = RIGHT
elif dir_rank == 0 and dir_file > 0:
move_direction = LEFT
elif dir_rank > 0 and dir_file == 0:
move_direction = DOWN
elif dir_rank > 0 and dir_file < 0:
move_direction = DOWN_RIGHT
elif dir_rank > 0 and dir_file > 0:
move_direction = DOWN_LEFT
else:
raise RuntimeError
# promote
if move_is_promotion(move):
move_direction = MOVE_DIRECTION_PROMOTED[move_direction]
else:
# drop from hand
move_direction = len(MOVE_DIRECTION) + move_drop_hand_piece(move) - 1
to_sq = move_to(move)
if color == cshogi.WHITE:
to_sq = 80 - to_sq
# the maximum label is 27*81-1 = 2186
move_label = move_direction * 81 + to_sq
return move_label
def get_result(result, color):
# draw
if result == 0:
return 0.5
# win for the side to move
elif ((result == 1) and (color == cshogi.BLACK)) or ((result == 2) and (color == cshogi.WHITE)):
return 1
else:
return 0
def get_policy_value_label(hcpe):
board.reset()
board.set_hcp(hcpe['hcp'])
seq = get_seq_from_board(board)
label = get_move_label(hcpe['bestMove16'], board.turn)
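# 0.0013226 ≈ 1 / 756.0864962951762, i.e. the inverse of the win-rate-to-centipawn conversion used by the players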
value = 1 / (1 + np.exp(-hcpe['eval'] * 0.0013226))
result = get_result(hcpe['gameResult'], board.turn)
return seq, label, value, result
def get_policy_value_label_from_moves(moves):
n = np.random.randint(4, len(moves)-1)
board.reset()
for move in moves[:n]:
board.push(move)
seq = get_seq_from_board(board)
label = get_move_label(moves[n], board.turn)
value = 0.5
result = 0.5
return seq, label, value, result
def get_moves_from_lines(line):
board.reset()
moves = []
for move_usi in line.split()[2:]:
move = board.move_from_usi(move_usi)
board.push(move)
moves.append(move)
return moves
================================================
FILE: src/model/__init__.py
================================================
================================================
FILE: src/model/bert.py
================================================
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM, BertModel
from src.utils.shogi import pieces_list
config = {
'vocab_size': len(pieces_list) + 4, # MASK_TOKEN_ID, MASK, CLS, SEP
'hidden_size': 768,
'num_hidden_layers': 12,
'num_attention_heads': 12,
'intermediate_size': 3072, # rule of thumb: hidden_size * 4
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'attention_probs_dropout_prob': 0.1,
'max_position_embeddings': 512, # 95 (= 81 squares + 7 black hand pieces + 7 white hand pieces) would probably suffice
'type_vocab_size': 1, # no sentence pairs are used, so token_type_embeddings is effectively wasted
'initializer_range': 0.02,
}
config = BertConfig.from_dict(config)
class BertMLM(nn.Module):
def __init__(self, model_dir=None):
super().__init__()
if model_dir is None:
self.bert = BertForMaskedLM(config)
else:
self.bert = BertForMaskedLM.from_pretrained(model_dir)
def forward(self, input_ids, labels):
return self.bert(input_ids=input_ids, labels=labels)
class BertPolicyValue(nn.Module):
def __init__(self, model_dir=None):
super().__init__()
if model_dir is None:
self.bert = BertModel(config)
else:
self.bert = BertModel.from_pretrained(model_dir)
self.policy_head = nn.Sequential(
nn.Linear(768, 768 * 2),
nn.Tanh(),
nn.Linear(768 * 2, 9 * 9 * 27)
)
self.value_head = nn.Sequential(
nn.Linear(768, 768 * 2),
nn.Tanh(),
nn.Linear(768 * 2, 1),
nn.Sigmoid()
)
self.loss_policy_fn = nn.CrossEntropyLoss()
self.loss_value_fn = nn.MSELoss()
def forward(self, input_ids, labels=None):
features = self.bert(input_ids=input_ids)['last_hidden_state']
policy = self.policy_head(features).mean(axis=1)
value = self.value_head(features).mean(axis=1).squeeze(1)
if labels is None:
return {'policy': policy, 'value': value}
else:
loss_policy = self.loss_policy_fn(policy, labels['labels'])
loss_value = self.loss_value_fn(value, labels['values'])
loss = loss_policy + loss_value
return {'loss_policy': loss_policy, 'loss_value': loss_value, 'loss': loss}
================================================
FILE: src/pl_modules/__init__.py
================================================
from .mlm import MLMModule, MLMDataModule
from .policy_value import PolicyValueModule, PolicyValueDataModule
def get_pl_modules(cfg):
if cfg.model_type == 'MLM':
return MLMModule(cfg), MLMDataModule(cfg)
elif cfg.model_type == 'PolicyValue':
return PolicyValueModule(cfg), PolicyValueDataModule(cfg)
else:
raise NotImplementedError
================================================
FILE: src/pl_modules/mlm.py
================================================
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AdamW
from src.data.mlm import MLMDataset
from src.model.bert import BertMLM
class MLMModule(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.model = BertMLM(hparams['model_dir'])
def forward(self, batch):
input_ids = batch['input_ids']
labels = batch['labels']
return self.model(input_ids=input_ids, labels=labels)
def training_step(self, batch, batch_idx):
outputs = self(batch)
loss = outputs[0]
self.log('loss', loss)
return loss
def validation_step(self, batch, batch_idx):
outputs = self(batch)
loss = outputs[0].detach().cpu().numpy()
return {'loss': loss}
def validation_epoch_end(self, outputs):
val_loss = np.mean([out['loss'] for out in outputs])
self.log('steps', self.global_step)
self.log('val_loss', val_loss)
def configure_optimizers(self):
return AdamW(self.parameters(), lr=5e-5)
class MLMDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
def setup(self, stage=None):
dataset_dir = Path(self.cfg.dataset_dir)
train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)
valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)
self.train_dataset = MLMDataset(train_data)
self.val_dataset = MLMDataset(valid_data)
def train_dataloader(self):
return DataLoader(self.train_dataset, **self.cfg.train_loader)
def val_dataloader(self):
return DataLoader(self.val_dataset, **self.cfg.val_loader)
================================================
FILE: src/pl_modules/policy_value.py
================================================
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AdamW
from src.data.policy_value import PolicyValueDataset
from src.model.bert import BertPolicyValue
class PolicyValueModule(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.model = BertPolicyValue(hparams['model_dir'])
def forward(self, input_ids, labels=None):
output = self.model(input_ids, labels)
return output
def training_step(self, batch, batch_idx):
input_ids = batch.pop('input_ids')
output = self(input_ids, batch)
return {'loss': output['loss']}
def validation_step(self, batch, batch_idx):
input_ids = batch.pop('input_ids')
output = self(input_ids, batch)
for k, v in output.items():
output[k] = v.detach().cpu().numpy()
return output
def validation_epoch_end(self, outputs):
val_loss = np.mean([out['loss'] for out in outputs])
val_loss_policy = np.mean([out['loss_policy'] for out in outputs])
val_loss_value = np.mean([out['loss_value'] for out in outputs])
self.log('val_loss', val_loss)
self.log('val_loss_policy', val_loss_policy)
self.log('val_loss_value', val_loss_value)
def configure_optimizers(self):
return AdamW(self.parameters(), lr=5e-5)
class PolicyValueDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
def setup(self, stage=None):
dataset_dir = Path(self.cfg.dataset_dir)
train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)
print('Load train data')
valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)
print('Load val data')
self.train_dataset = PolicyValueDataset(train_data)
self.val_dataset = PolicyValueDataset(valid_data)
def train_dataloader(self):
return DataLoader(self.train_dataset, **self.cfg.train_loader)
def val_dataloader(self):
return DataLoader(self.val_dataset, **self.cfg.val_loader)
================================================
FILE: src/player/__init__.py
================================================
================================================
FILE: src/player/base_player.py
================================================
import cshogi
class BasePlayer:
def __init__(self):
self.board = cshogi.Board()
def usi(self):
print('id name bert_player')
print('usiok')
def usinewgame(self):
pass
def setoption(self, option):
pass
def isready(self):
pass
def position(self, moves):
if moves[0] == 'startpos':
self.board.reset()
for move in moves[2:]:
self.board.push_usi(move)
elif moves[0] == 'sfen':
# "position sfen <sfen> [moves ...]": apply any trailing moves separately
if 'moves' in moves:
i = moves.index('moves')
self.board.set_sfen(' '.join(moves[1:i]))
for move in moves[i + 1:]:
self.board.push_usi(move)
else:
self.board.set_sfen(' '.join(moves[1:]))
# for debug
print(self.board.sfen())
def go(self):
pass
def quit(self):
pass
================================================
FILE: src/player/mcts_player.py
================================================
import time
from argparse import ArgumentParser
from pathlib import Path
import cshogi
import numpy as np
import torch
from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.pl_modules.policy_value import PolicyValueModule
from src.player.base_player import BasePlayer
from src.player.usi import usi
from src.uct.uct_node import NOT_EXPANDED, NodeHash, UCT_HASH_SIZE, UctNode
from src.utils.misc import boltzmann
class MCTSPlayer(BasePlayer):
def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign_threshold=0.01, c_puct=1):
super().__init__()
self.ckpt_path = ckpt_path
self.playout_halt = playout_halt
self.temperature = temperature
self.resign_threshold = resign_threshold
self.c_puct = c_puct
self.model = None
self.node_hash = NodeHash()
self.uct_nodes = [UctNode() for _ in range(UCT_HASH_SIZE)]
self.current_n_idx = None
self.playout_count = 0
def usi(self):
print('id name BERT-MCTS')
print('usiok')
def isready(self):
if self.model is None:
self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
self.model.cuda()
self.model.eval()
self.node_hash.initialize()
print('readyok')
def go(self):
if self.board.is_game_over():
print('bestmove resign')
return
# clear search statistics
self.playout_count = 0
# delete stale hash entries
self.node_hash.delete_old_hash(self.board, self.uct_nodes)
# search start time
begin_time = time.time()
# expand the root node
self.current_n_idx = self.expand_node()
# if there is only one candidate move, play it immediately
current_node = self.uct_nodes[self.current_n_idx]
child_moves = current_node.child_moves
child_num = len(child_moves)
if child_num == 1:
print('bestmove', cshogi.move_to_usi(child_moves[0]))
return
def get_bestmove_and_print_info():
# elapsed search time
finish_time = time.time() - begin_time
if self.board.move_number < 5:
selected_index = np.random.choice(np.arange(child_num), p=current_node.policy)
else:
# choose the most-visited move
selected_index = np.argmax(current_node.child_moves_count)
# if the selected child was never visited, fall back to its policy prior
if current_node.child_moves_count[selected_index] == 0:
best_wp = current_node.policy[selected_index]
# otherwise use its mean win rate
else:
best_wp = current_node.child_value_sum[selected_index] / current_node.child_moves_count[selected_index]
# resign if the win rate is below the threshold
if best_wp < self.resign_threshold:
bestmove = 'resign'
else:
bestmove = cshogi.move_to_usi(child_moves[selected_index])
# convert the value to the centipawn scale
if best_wp == 1:
cp = 30000
elif best_wp == 0:
cp = -30000
else:
cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)
nps = int(current_node.move_count / finish_time)
time_secs = int(finish_time * 1000)
nodes = current_node.move_count
hashfull = self.node_hash.get_usage_rate() * 100
print(f'info score cp {cp} hashfull {hashfull:.2f} time {time_secs} nodes {nodes} nps {nps} pv {bestmove}')
return bestmove
# repeat playouts; exit when the playout budget is exhausted or the search is cut off early
while self.playout_count < self.playout_halt:
self.playout_count += 1
self.uct_search(self.current_n_idx)
# print search info every 10 playouts
if (self.playout_count+1) % 10 == 0:
get_bestmove_and_print_info()
# check whether to stop the search early
if self.interruption_check() or not self.node_hash.enough_size:
break
bestmove = get_bestmove_and_print_info()
print('bestmove', bestmove)
def expand_node(self):
current_hash = self.board.zobrist_hash()
current_turn = self.board.turn
current_move_number = self.board.move_number
n_idx = self.node_hash.find_same_hash_index(current_hash, current_turn, current_move_number)
# if a transposition to an existing node is found, return it
if n_idx != UCT_HASH_SIZE:
return n_idx
# otherwise find an empty slot
n_idx = self.node_hash.search_empty_index(current_hash, current_turn, current_move_number)
# initialize the current node
current_node = self.uct_nodes[n_idx]
current_node.reset()
# expand the candidate moves
current_node.child_moves = [move for move in self.board.legal_moves]
child_num = len(current_node.child_moves)
current_node.child_n_indices = [NOT_EXPANDED for _ in range(child_num)]
current_node.child_moves_count = np.zeros(child_num, dtype=np.int32)
current_node.child_value_sum = np.zeros(child_num, dtype=np.float32)
# evaluate the node
self.eval_node(n_idx)
return n_idx
def eval_node(self, n_idx):
current_node = self.uct_nodes[n_idx]
child_moves = current_node.child_moves
child_num = len(current_node.child_moves)
if child_num == 0:
# no legal moves means the side to move has lost
current_node.value = 0
current_node.evaled = True
else:
# policy and value for the current position
seq = get_seq_from_board(self.board)
input_ids = torch.tensor([seq], dtype=torch.long).cuda()
with torch.no_grad():
output = self.model(input_ids)
value = output['value'].detach().cpu().numpy()[0]
policy_logits = output['policy'].detach().cpu().numpy()[0]
# filter down to the legal moves
legal_move_labels = []
for move in child_moves:
legal_move_labels.append(get_move_label(move, self.board.turn))
# Boltzmann distribution over the legal-move logits
policy = boltzmann(policy_logits[legal_move_labels], self.temperature)
# update the node
current_node.policy = policy
current_node.value = value
current_node.evaled = True
def uct_search(self, n_idx):
current_node = self.uct_nodes[n_idx]
child_moves = current_node.child_moves
child_n_indices = current_node.child_n_indices
child_num = len(child_moves)
# mate means a loss here, i.e. a win from the parent node's perspective
if child_num == 0:
return 1
# pick the move with the highest UCB value (PUCT)
next_c_idx = self.select_max_ucb_child(n_idx)
next_move = child_moves[next_c_idx]
next_n_idx = child_n_indices[next_c_idx]
# play the selected move
self.board.push(next_move)
if next_n_idx == NOT_EXPANDED:
# the child node for the selected move is unexpanded: expand it
# (expansion also evaluates the new node)
next_n_idx = self.expand_node()
child_n_indices[next_c_idx] = next_n_idx
child_node = self.uct_nodes[next_n_idx]
result = 1 - child_node.value
else:
# already expanded: search one ply deeper
result = self.uct_search(next_n_idx)
# backup: propagate the search result
current_node.move_count += 1
current_node.child_value_sum[next_c_idx] += result
current_node.child_moves_count[next_c_idx] += 1
# undo the move
self.board.pop(next_move)
return 1 - result
def select_max_ucb_child(self, c_idx):
current_node = self.uct_nodes[c_idx]
child_wins_count = current_node.child_value_sum
child_moves_count = current_node.child_moves_count
child_num = len(child_moves_count)
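# PUCT: ucb = q + c_puct * policy * sqrt(move_count) / (1 + child_moves_count)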
# where a child has never been visited, fill q with 0.5
q = np.divide(child_wins_count, child_moves_count, out=np.repeat(0.5, child_num), where=child_moves_count != 0)
u = np.sqrt(current_node.move_count) / (1 + child_moves_count)
ucb = q + self.c_puct * current_node.policy * u
return np.argmax(ucb)
# check whether the search can be stopped early
def interruption_check(self):
child_move_count = self.uct_nodes[self.current_n_idx].child_moves_count
rest = self.playout_halt - self.playout_count
# find the most-visited and second-most-visited moves
second, first = child_move_count[np.argpartition(child_move_count, -2)[-2:]]
# stop early if the runner-up cannot overtake the best move even with all remaining playouts
if first - second > rest:
return True
else:
return False
def argparse():
parser = ArgumentParser()
parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
args, _ = parser.parse_known_args()
print('Command Line Args:')
print(args)
return args
def main(args):
ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
player = MCTSPlayer(ckpt_path)
usi(player)
if __name__ == '__main__':
args = argparse()
main(args)
================================================
FILE: src/player/policy_player.py
================================================
from argparse import ArgumentParser
from pathlib import Path
import numpy as np
import cshogi
import torch
import torch.nn.functional as F
from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.pl_modules.policy_value import PolicyValueModule
from src.player.base_player import BasePlayer
from src.player.usi import usi
from src.utils.misc import greedy
class PolicyPlayer(BasePlayer):
def __init__(self, ckpt_path):
super().__init__()
self.ckpt_path = ckpt_path
self.model = None
def usi(self):
print('id name BERT-Policy')
print('usiok')
def isready(self):
if self.model is None:
self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
self.model.cuda()
self.model.eval()
print('readyok')
def go(self):
if self.board.is_game_over():
print('bestmove resign')
return
seq = get_seq_from_board(self.board)
input_ids = torch.tensor([seq], dtype=torch.long).cuda()
with torch.no_grad():
output = self.model(input_ids)
policy = F.softmax(output['policy'], dim=1).detach().cpu().numpy()[0]
# for every legal move
legal_moves = []
legal_policy = []
for move in self.board.legal_moves:
# convert the move to a label
label = get_move_label(move, self.board.turn)
# store the legal move and its probability
legal_moves.append(move)
legal_policy.append(policy[label])
selected_index = greedy(legal_policy)
bestmove = cshogi.move_to_usi(legal_moves[selected_index])
best_wp = legal_policy[selected_index]
# convert to the centipawn scale (best_wp here is the move probability)
if best_wp == 1:
cp = 30000
elif best_wp == 0:
cp = -30000
else:
cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)
print(f'info score cp {cp} pv {bestmove}')
print('bestmove', bestmove)
def argparse():
parser = ArgumentParser()
parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
args, _ = parser.parse_known_args()
print('Command Line Args:')
print(args)
return args
def main(args):
ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
player = PolicyPlayer(ckpt_path)
usi(player)
if __name__ == '__main__':
args = argparse()
main(args)
================================================
FILE: src/player/usi.py
================================================
from src.player.base_player import BasePlayer
def usi(player: BasePlayer):
while True:
cmd_line = input()
cmd = cmd_line.split(' ', 1)
cmd = [c.rstrip() for c in cmd]
if cmd[0] == 'usi':
player.usi()
elif cmd[0] == 'setoption':
option = cmd[1].split(' ')
player.setoption(option)
elif cmd[0] == 'isready':
player.isready()
elif cmd[0] == 'usinewgame':
player.usinewgame()
elif cmd[0] == 'position':
moves = cmd[1].split(' ')
player.position(moves)
elif cmd[0] == 'go':
player.go()
elif cmd[0] == 'quit':
player.quit()
break
================================================
FILE: src/uct/__init__.py
================================================
================================================
FILE: src/uct/uct_node.py
================================================
# upper bound on the number of nodes
UCT_HASH_SIZE = 4096
# index marking an unexpanded node
NOT_EXPANDED = -1
# fold a Zobrist hash value down to a UCT_HASH_SIZE-sized index
def hash_to_index(zhash):
return ((zhash & 0xffffffff) ^ ((zhash >> 32) & 0xffffffff)) & (UCT_HASH_SIZE - 1)
class NodeHashEntry:
def __init__(self):
self.hash = 0 # Zobrist hash value
self.color = 0 # side to move
self.moves = 0 # number of moves since the start of the game
self.flag = False # whether this entry is in use
def reset(self):
self.hash = 0
self.color = 0
self.moves = 0
self.flag = False
class NodeHash:
def __init__(self):
self.used = 0
self.enough_size = True
self.node_hash = None
def initialize(self):
self.used = 0
self.enough_size = True
if self.node_hash is None:
self.node_hash = [NodeHashEntry() for _ in range(UCT_HASH_SIZE)]
else:
for i in range(UCT_HASH_SIZE):
self.node_hash[i].reset()
# find and return an unused index (open addressing with linear probing)
def search_empty_index(self, zhash, color, moves):
key = hash_to_index(zhash)
i = key
while True:
if not self.node_hash[i].flag:
self.node_hash[i].hash = zhash
self.node_hash[i].color = color
self.node_hash[i].moves = moves
self.node_hash[i].flag = True
self.used += 1
if self.get_usage_rate() > 0.9:
self.enough_size = False
return i
i += 1
if i >= UCT_HASH_SIZE:
i = 0
# wrapped around to the start: the table is full
if i == key:
return UCT_HASH_SIZE
# return the index holding the given hash, or UCT_HASH_SIZE if it is not in the table
def find_same_hash_index(self, zhash, color, moves):
key = hash_to_index(zhash)
i = key
# probe linearly, mirroring search_empty_index above
while True:
if not self.node_hash[i].flag:
return UCT_HASH_SIZE
elif self.node_hash[i].hash == zhash and self.node_hash[i].color == color and \
self.node_hash[i].moves == moves:
return i
i += 1
if i >= UCT_HASH_SIZE:
i = 0
if i == key:
return UCT_HASH_SIZE
# keep the nodes that are still in use
def save_used_hash(self, board, uct_nodes, n_idx):
self.node_hash[n_idx].flag = True
self.used += 1
current_node = uct_nodes[n_idx]
child_n_indices = current_node.child_n_indices
child_moves = current_node.child_moves
child_num = len(child_moves)
for i in range(child_num):
if child_n_indices[i] != NOT_EXPANDED and not self.node_hash[child_n_indices[i]].flag:
board.push(child_moves[i])
self.save_used_hash(board, uct_nodes, child_n_indices[i])
board.pop(child_moves[i])
# delete stale entries
def delete_old_hash(self, board, uct_node):
# keep only the subtree rooted at the current position
n_idx = self.find_same_hash_index(board.zobrist_hash(), board.turn, board.move_number)
self.used = 0
for i in range(UCT_HASH_SIZE):
self.node_hash[i].reset()
if n_idx != UCT_HASH_SIZE:
self.save_used_hash(board, uct_node, n_idx)
self.enough_size = True
def get_usage_rate(self):
return self.used / UCT_HASH_SIZE
class UctNode:
def __init__(self):
self.evaled = False # evaluated flag
self.move_count = 0 # number of visits to this node
self.value = 0 # value-network output for this node (predicted win rate)
self.policy = None # normalized policy over the child moves
self.child_moves = None # child moves
self.child_n_indices = None # indices of the child nodes
self.child_moves_count = None # visit counts of the children (for UCB)
self.child_value_sum = None # summed values of the children (for UCB)
def reset(self):
self.evaled = False
self.move_count = 0
self.value = 0
self.policy = None
self.child_moves = None
self.child_n_indices = None
self.child_moves_count = None
self.child_value_sum = None
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/hcpe.py
================================================
import cshogi
import numpy as np
from tqdm import tqdm
from src.features.policy_value import get_policy_value_label
def get_data_from_hcpe(hcpes):
data = []
for hcpe in tqdm(hcpes):
data.append(get_policy_value_label(hcpe))
data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])
return data
def load_hcpes(hcpe_paths):
hcpe_list = []
for hcpe_path in hcpe_paths:
hcpe_list.append(np.fromfile(hcpe_path, dtype=cshogi.HuffmanCodedPosAndEval))
hcpes = np.concatenate(hcpe_list)
return hcpes
================================================
FILE: src/utils/misc.py
================================================
import numpy as np
def greedy(logits):
return np.asarray(logits).argmax()
def boltzmann(logits, temperature):
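# note: modifies the logits array in place; callers pass a fresh copy (fancy indexing in eval_node)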
logits /= temperature
logits -= logits.max()
probabilities = np.exp(logits)
probabilities /= probabilities.sum()
return probabilities
================================================
FILE: src/utils/sfen.py
================================================
from pathlib import Path
import numpy as np
from src.features.policy_value import get_moves_from_lines, get_policy_value_label_from_moves
def get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):
assert base_dir.name == 'ShogiGokakuKyokumen'
sfen_paths = []
# TODO: use a regular expression
for sfen_path in sorted(list(base_dir.glob('*/*/*.sfen')) + list(base_dir.glob('*/*/*/*.sfen'))):
# even ('互角') positions reached in at most max_num_of_moves moves
if '互角' in sfen_path.stem and (int(sfen_path.stem.split('手目')[0][-3:]) <= max_num_of_moves):
sfen_paths.append(sfen_path)
return sfen_paths
def get_data_from_sfen(sfens):
# for sfen data, training samples are taken from the middle of each game record
data = []
for sfen in sfens:
moves = get_moves_from_lines(sfen)
# use only records with at least 6 moves
if len(moves) < 6:
continue
# extract one position per ~5 moves (e.g. 6 positions from a 30-move game)
for _ in range(max(1, len(moves) // 5)):
dt = get_policy_value_label_from_moves(moves)
data.append(dt)
data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])
return data
def load_sfen(sfen_path):
with open(sfen_path) as f:
sfens = [l.rstrip() for l in f.readlines()]
return sfens
def load_sfens(sfen_paths):
sfens = []
for kifu_path in sfen_paths:
sfens.extend(load_sfen(kifu_path))
return sfens
================================================
FILE: src/utils/shogi.py
================================================
import numpy as np
# move direction constants
MOVE_DIRECTION = [
UP, UP_LEFT, UP_RIGHT, LEFT, RIGHT, DOWN, DOWN_LEFT, DOWN_RIGHT, UP2_LEFT, UP2_RIGHT,
UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,
DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
] = range(20)
# promotion conversion table
MOVE_DIRECTION_PROMOTED = [
UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,
DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
]
# number of move-label types
MOVE_DIRECTION_LABEL_NUM = len(MOVE_DIRECTION) + 7 # 7 droppable piece types
# dict used by reverse_piece_fn to swap black and white pieces
bw_dict = {k: k + 16 for k in range(1, 15)} # 1-14: own pieces, 17-30: opponent pieces
wb_dict = {v: k for k, v in bw_dict.items()}
pieces_dict = {**bw_dict, **wb_dict, 0: 0} # 0 is an empty square
pieces_list = list(pieces_dict.keys())
reverse_piece_fn = np.vectorize(pieces_dict.get)
================================================
FILE: tools/download_and_build_lesserkai.sh
================================================
wget http://shogidokoro.starfree.jp/download/LesserkaiSrc.zip -P ./work_dirs
unzip ./work_dirs/LesserkaiSrc.zip -d ./work_dirs/
rm ./work_dirs/LesserkaiSrc.zip
cd ./work_dirs/LesserkaiSrc/Lesserkai
make
================================================
FILE: tools/make_dataset.py
================================================
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
from src.utils.hcpe import get_data_from_hcpe, load_hcpes
from src.utils.sfen import get_data_from_sfen, get_gokaku_sfen_paths, load_sfens
def main():
gokaku_dir = Path('./data/ShogiGokakuKyokumen/')
hcpe_dir = Path('./data/hcpe/')
dataset_base_dir = Path('./data/dataset/')
dataset_base_dir.mkdir(exist_ok=True)
for num_of_moves in [40, 100]:
sfen_paths = get_gokaku_sfen_paths(gokaku_dir, num_of_moves)
sfens = load_sfens(sfen_paths)
data = get_data_from_sfen(sfens)
train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)
dataset_dir = dataset_base_dir / f'gokaku_{num_of_moves:03d}'
dataset_dir.mkdir()
np.save(dataset_dir / 'train.npy', train_data)
np.save(dataset_dir / 'val.npy', valid_data)
# self-play game records
hcpe_paths = sorted(hcpe_dir.glob('selfplay-*'))
hcpes = load_hcpes(hcpe_paths)
data = get_data_from_hcpe(hcpes)
train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)
dataset_dir = dataset_base_dir / 'selfplay'
dataset_dir.mkdir()
np.save(dataset_dir / 'train.npy', train_data)
np.save(dataset_dir / 'val.npy', valid_data)
if __name__ == '__main__':
main()
================================================
FILE: tools/pl_to_transformers.py
================================================
from argparse import ArgumentParser
from pathlib import Path
import torch
from src.model.bert import config
def argparse():
parser = ArgumentParser(description='Convert pl checkpoint to transformers format')
parser.add_argument('ckpt_path', type=str)
args, _ = parser.parse_known_args()
return args
def main(args):
ckpt_path = Path(args.ckpt_path)
ckpt_dir = ckpt_path.parent
ckpt = torch.load(ckpt_path)
state_dict = ckpt['state_dict']
state_dict = {'.'.join(k.split('.')[2:]): v for k, v in state_dict.items()}
# pytorch_model.bin and config.json must be in the same directory
state_dict_path = ckpt_dir / 'pytorch_model.bin'
torch.save(state_dict, state_dict_path)
config.to_json_file(ckpt_dir / 'config.json')
if __name__ == '__main__':
args = argparse()
main(args)
================================================
FILE: tools/test_engine.py
================================================
from argparse import ArgumentParser
from pathlib import Path
from cshogi import cli
def parse_args():
parser = ArgumentParser()
parser.add_argument('engine_path')
return parser.parse_args()
def main(args):
benchmark = str(Path('./work_dirs/LesserkaiSrc/Lesserkai/Lesserkai').resolve())
engine = str((Path(args.engine_path)).resolve())
print('benchmark: ', benchmark)
print('engine: ', engine)
cli.main(benchmark, engine)
if __name__ == '__main__':
args = parse_args()
main(args)
================================================
FILE: tools/train.py
================================================
from argparse import ArgumentParser
from pathlib import Path
import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from src.pl_modules import get_pl_modules
def argparse():
parser = ArgumentParser()
parser.add_argument('cfg', type=str)
parser.add_argument('--log_dir', type=str, default='./work_dirs')
parser.add_argument('--ckpt_path', type=str, default='')
parser = pl.Trainer.add_argparse_args(parser)
args, _ = parser.parse_known_args()
return args
def main(args):
cfg = OmegaConf.load(args.cfg)
pl.seed_everything(cfg.seed)
# apply train_params from the config onto the parsed args
for k, v in cfg.train_params.items():
setattr(args, k, v)
# Disable default checkpoint callback
args.checkpoint_callback = False
trainer = pl.Trainer.from_argparse_args(args)
trainer.logger = pl_loggers.TensorBoardLogger(save_dir=args.log_dir, name=Path(args.cfg).stem, default_hp_metric=False)
trainer.callbacks.append(ModelCheckpoint(filename='{step:07d}-{val_loss:.2f}', monitor='val_loss',
save_top_k=1, save_last=True))
model, data = get_pl_modules(cfg)
if args.ckpt_path:
model = model.load_from_checkpoint(args.ckpt_path)
trainer.fit(model, datamodule=data)
if __name__ == '__main__':
args = argparse()
main(args)
SYMBOL INDEX (104 symbols across 20 files)
FILE: src/data/mlm.py
class MLMDataset (line 9) | class MLMDataset(Dataset):
method __init__ (line 10) | def __init__(self, data):
method __len__ (line 15) | def __len__(self):
method __getitem__ (line 18) | def __getitem__(self, idx):
FILE: src/data/policy_value.py
class PolicyValueDataset (line 5) | class PolicyValueDataset(Dataset):
method __init__ (line 6) | def __init__(self, data):
method __len__ (line 9) | def __len__(self):
method __getitem__ (line 12) | def __getitem__(self, idx):
FILE: src/features/common.py
function get_seq_from_board (line 6) | def get_seq_from_board(board):
FILE: src/features/policy_value.py
function get_move_label (line 12) | def get_move_label(move, color):
function get_result (line 64) | def get_result(result, color):
function get_policy_value_label (line 75) | def get_policy_value_label(hcpe):
function get_policy_value_label_from_moves (line 87) | def get_policy_value_label_from_moves(moves):
function get_moves_from_lines (line 99) | def get_moves_from_lines(line):
FILE: src/model/bert.py
class BertMLM (line 22) | class BertMLM(nn.Module):
method __init__ (line 23) | def __init__(self, model_dir=None):
method forward (line 30) | def forward(self, input_ids, labels):
class BertPolicyValue (line 34) | class BertPolicyValue(nn.Module):
method __init__ (line 35) | def __init__(self, model_dir=None):
method forward (line 58) | def forward(self, input_ids, labels=None):
FILE: src/pl_modules/__init__.py
function get_pl_modules (line 5) | def get_pl_modules(cfg):
FILE: src/pl_modules/mlm.py
class MLMModule (line 12) | class MLMModule(pl.LightningModule):
method __init__ (line 14) | def __init__(self, hparams):
method forward (line 19) | def forward(self, batch):
method training_step (line 24) | def training_step(self, batch, batch_idx):
method validation_step (line 30) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 35) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 40) | def configure_optimizers(self):
class MLMDataModule (line 44) | class MLMDataModule(pl.LightningDataModule):
method __init__ (line 45) | def __init__(self, cfg):
method setup (line 49) | def setup(self, stage=None):
method train_dataloader (line 56) | def train_dataloader(self):
method val_dataloader (line 59) | def val_dataloader(self):
FILE: src/pl_modules/policy_value.py
class PolicyValueModule (line 12) | class PolicyValueModule(pl.LightningModule):
method __init__ (line 13) | def __init__(self, hparams):
method forward (line 18) | def forward(self, input_ids, labels=None):
method training_step (line 22) | def training_step(self, batch, batch_idx):
method validation_step (line 27) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 34) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 42) | def configure_optimizers(self):
class PolicyValueDataModule (line 46) | class PolicyValueDataModule(pl.LightningDataModule):
method __init__ (line 47) | def __init__(self, cfg):
method setup (line 51) | def setup(self, stage=None):
method train_dataloader (line 60) | def train_dataloader(self):
method val_dataloader (line 63) | def val_dataloader(self):
FILE: src/player/base_player.py
class BasePlayer (line 4) | class BasePlayer:
method __init__ (line 5) | def __init__(self):
method usi (line 8) | def usi(self):
method usinewgame (line 12) | def usinewgame(self):
method setoption (line 15) | def setoption(self, option):
method isready (line 18) | def isready(self):
method position (line 21) | def position(self, moves):
method go (line 31) | def go(self):
method quit (line 34) | def quit(self):
FILE: src/player/mcts_player.py
class MCTSPlayer (line 18) | class MCTSPlayer(BasePlayer):
method __init__ (line 19) | def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign...
method usi (line 33) | def usi(self):
method isready (line 37) | def isready(self):
method go (line 45) | def go(self):
method expand_node (line 123) | def expand_node(self):
method eval_node (line 152) | def eval_node(self, n_idx):
method uct_search (line 182) | def uct_search(self, n_idx):
method select_max_ucb_child (line 221) | def select_max_ucb_child(self, c_idx):
method interruption_check (line 235) | def interruption_check(self):
function argparse (line 249) | def argparse():
function main (line 258) | def main(args):
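select_max_ucb_child is the heart of the search: it picks the child maximising an AlphaZero-style PUCT score. A sketch over assumed UctNode array attributes (child_wins, child_visits, policy):

import numpy as np

C_PUCT = 1.0  # exploration constant; the repo's value is not visible here

def select_max_ucb_child_sketch(node):
    # score = Q + c_puct * P * sqrt(total_visits) / (1 + child_visits)
    visits = node.child_visits.astype(np.float64)
    q = np.divide(node.child_wins, visits,
                  out=np.zeros_like(visits), where=visits > 0)
    u = C_PUCT * node.policy * np.sqrt(visits.sum()) / (1.0 + visits)
    return int(np.argmax(q + u))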
FILE: src/player/policy_player.py
class PolicyPlayer (line 17) | class PolicyPlayer(BasePlayer):
method __init__ (line 18) | def __init__(self, ckpt_path):
method usi (line 23) | def usi(self):
method isready (line 27) | def isready(self):
method go (line 34) | def go(self):
function argparse (line 72) | def argparse():
function main (line 81) | def main(args):
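PolicyPlayer skips the tree search entirely: go evaluates the position once and answers with the highest-scoring legal move. A sketch of that wiring using the repo's helpers; the exact model output format is assumed:

import cshogi
import torch

from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.utils.misc import greedy

def go_sketch(model, board):
    seq = torch.tensor([get_seq_from_board(board)], dtype=torch.long)
    with torch.no_grad():
        policy_logits, _ = model(seq)               # assumed (policy, value) output
    moves = list(board.legal_moves)
    labels = [get_move_label(m, board.turn) for m in moves]
    scores = policy_logits[0, labels].numpy()       # logits of the legal moves only
    print('bestmove', cshogi.move_to_usi(moves[greedy(scores)]))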
FILE: src/player/usi.py
function usi (line 4) | def usi(player: BasePlayer):
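usi is the engine's read-eval loop: it reads one USI command per line from stdin and dispatches to the player's methods (the same names BasePlayer defines above). A sketch; the argument shapes passed to setoption and position are assumed:

def usi_loop_sketch(player):
    while True:
        tokens = input().split()
        if not tokens:
            continue
        cmd, args = tokens[0], tokens[1:]
        if cmd == 'usi':
            player.usi()
        elif cmd == 'isready':
            player.isready()
        elif cmd == 'usinewgame':
            player.usinewgame()
        elif cmd == 'setoption':
            player.setoption(args)
        elif cmd == 'position':
            player.position(args)
        elif cmd == 'go':
            player.go()
        elif cmd == 'quit':
            player.quit()
            break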
FILE: src/uct/uct_node.py
function hash_to_index (line 8) | def hash_to_index(zhash):
class NodeHashEntry (line 12) | class NodeHashEntry:
method __init__ (line 13) | def __init__(self):
method reset (line 19) | def reset(self):
class NodeHash (line 26) | class NodeHash:
method __init__ (line 27) | def __init__(self):
method initialize (line 32) | def initialize(self):
method search_empty_index (line 43) | def search_empty_index(self, zhash, color, moves):
method find_same_hash_index (line 65) | def find_same_hash_index(self, zhash, color, moves):
method save_used_hash (line 78) | def save_used_hash(self, board, uct_nodes, n_idx):
method delete_old_hash (line 93) | def delete_old_hash(self, board, uct_node):
method get_usage_rate (line 106) | def get_usage_rate(self):
class UctNode (line 110) | class UctNode:
method __init__ (line 111) | def __init__(self):
method reset (line 121) | def reset(self):
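The node table is a fixed-size, open-addressed hash map keyed by Zobrist hash. UCT_HASH_SIZE = 4096 and NOT_EXPANDED = -1 come from the file itself; the entry fields used below are assumptions. hash_to_index works because the table size is a power of two:

UCT_HASH_SIZE = 4096
NOT_EXPANDED = -1

def hash_to_index_sketch(zhash):
    # Fold a 64-bit Zobrist hash into [0, UCT_HASH_SIZE).
    return zhash & (UCT_HASH_SIZE - 1)

def search_empty_index_sketch(entries, zhash, color, moves):
    # Linear probing from the hashed slot; returns UCT_HASH_SIZE when full.
    start = i = hash_to_index_sketch(zhash)
    while True:
        if not entries[i].used:                      # 'used' flag is assumed
            entries[i].hash, entries[i].color, entries[i].moves = zhash, color, moves
            entries[i].used = True
            return i
        i = (i + 1) & (UCT_HASH_SIZE - 1)
        if i == start:
            return UCT_HASH_SIZE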
FILE: src/utils/hcpe.py
function get_data_from_hcpe (line 8) | def get_data_from_hcpe(hcpes):
function load_hcpes (line 17) | def load_hcpes(hcpe_paths):
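An .hcpe file is a flat array of fixed-size HuffmanCodedPosAndEval records, so it can be read directly with numpy's structured dtypes. A sketch of load_hcpes on that assumption, using the dtype cshogi ships:

import numpy as np
from cshogi import HuffmanCodedPosAndEval

def load_hcpes_sketch(hcpe_paths):
    return np.concatenate([np.fromfile(p, dtype=HuffmanCodedPosAndEval)
                           for p in hcpe_paths])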
FILE: src/utils/misc.py
function greedy (line 4) | def greedy(logits):
function boltzmann (line 8) | def boltzmann(logits, temperature):
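greedy and boltzmann are the two move-selection rules: argmax for deterministic play and temperature-scaled softmax sampling for exploration. greedy's one-liner is visible in the source; the boltzmann below is a standard formulation and may differ from the repo's in detail:

import numpy as np

def greedy(logits):
    return np.asarray(logits).argmax()

def boltzmann(logits, temperature):
    # Sample an index with probability softmax(logits / temperature);
    # subtracting the max keeps the exponentials numerically stable.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)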
FILE: src/utils/sfen.py
function get_gokaku_sfen_paths (line 8) | def get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):
function get_data_from_sfen (line 19) | def get_data_from_sfen(sfens):
function load_sfen (line 36) | def load_sfen(sfen_path):
function load_sfens (line 42) | def load_sfens(sfen_paths):
FILE: tools/make_dataset.py
function main (line 10) | def main():
FILE: tools/pl_to_transformers.py
function argparse (line 9) | def argparse():
function main (line 16) | def main(args):
FILE: tools/test_engine.py
function parse_args (line 7) | def parse_args():
function main (line 13) | def main(args):
FILE: tools/train.py
function argparse (line 12) | def argparse():
function main (line 22) | def main(args):
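tools/train.py glues the pieces together: load a YAML config with OmegaConf, build the matching module/data-module pair via get_pl_modules, and hand both to a Lightning Trainer. A sketch with assumed Trainer settings and an assumed (module, datamodule) return shape:

import pytorch_lightning as pl
from omegaconf import OmegaConf

from src.pl_modules import get_pl_modules

def train_sketch(config_path):
    cfg = OmegaConf.load(config_path)
    pl.seed_everything(cfg.seed)
    module, data_module = get_pl_modules(cfg)       # return shape assumed
    trainer = pl.Trainer(gpus=1, max_epochs=10)     # illustrative settings
    trainer.fit(module, datamodule=data_module)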