Repository: nyoki-mtl/bert-mcts-youtube
Branch: main
Commit: a12e0bdaf313
Files: 41
Total size: 52.3 KB

Directory structure:
bert-mcts-youtube/

├── .gitignore
├── Makefile
├── README.md
├── configs/
│   ├── .gitkeep
│   ├── mlm_base.yaml
│   └── policy_value_base.yaml
├── docker/
│   ├── Dockerfile
│   └── entrypoint.sh
├── engine/
│   ├── mcts_player.sh
│   └── policy_player.sh
├── env_name.yml
├── requirements.txt
├── setup.py
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── mlm.py
│   │   └── policy_value.py
│   ├── features/
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── policy_value.py
│   ├── model/
│   │   ├── __init__.py
│   │   └── bert.py
│   ├── pl_modules/
│   │   ├── __init__.py
│   │   ├── mlm.py
│   │   └── policy_value.py
│   ├── player/
│   │   ├── __init__.py
│   │   ├── base_player.py
│   │   ├── mcts_player.py
│   │   ├── policy_player.py
│   │   └── usi.py
│   ├── uct/
│   │   ├── __init__.py
│   │   └── uct_node.py
│   └── utils/
│       ├── __init__.py
│       ├── hcpe.py
│       ├── misc.py
│       ├── sfen.py
│       └── shogi.py
└── tools/
    ├── download_and_build_lesserkai.sh
    ├── make_dataset.py
    ├── pl_to_transformers.py
    ├── test_engine.py
    └── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Project specific folders
lightning_logs/
tf_logs/
remote_logs/
local_logs/
*_logs/
checkpoints/
dataset/
local_misc/
notebooks/data/

# LSF logfiles
lsf.*

# IDEA folders
.idea/

/work_dirs/*
/data/*
!.gitkeep


================================================
FILE: Makefile
================================================
PROJECT ?= bert-mcts
DATADIR ?= ${PWD}/data
WORKSPACE ?= /workspace/$(PROJECT)
DOCKER_IMAGE ?= ${PROJECT}:latest

SHMSIZE ?= 100G
DOCKER_OPTS := \
			--name ${PROJECT} \
			--rm -it \
			--shm-size=${SHMSIZE} \
			-v ${PWD}:${WORKSPACE} \
			-v ${DATADIR}:${WORKSPACE}/data \
			-v ${LOG_DIR}:${WORKSPACE}/work_dirs/logs \
      -w ${WORKSPACE} \
			--ipc=host \
			--network=host \
			--gpus all

docker-build:
	docker build -f docker/Dockerfile -t ${DOCKER_IMAGE} .

docker-start-interactive: docker-build
	docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash

docker-start-jupyter: docker-build
	docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \
		bash -c "jupyter lab --port=8888 --ip=0.0.0.0 --allow-root --no-browser"

docker-run: docker-build
	docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \
		bash -c "${COMMAND}"


================================================
FILE: README.md
================================================
# BERT-MCTS-YOUTUBE

A shogi engine that played against Takumi of Yobinori on YouTube.
It combines BERT, a natural-language model, with Monte Carlo tree search (MCTS).  
Search is slow because everything is written in Python.  

Most of the code apart from BERT is based on the book 『将棋AIで学ぶディープラーニング』 (Deep Learning through Shogi AI).
- [Book (amazon)](https://www.amazon.co.jp/dp/B07B7JJ929)
- [GitHub](https://github.com/TadaoYamaoka/python-dlshogi)

## Environment

### Colab

If you just want to test it, [Google Colab](https://colab.research.google.com/drive/10KAuLlNe6FKZBp_iE2bQJPNhoY2WeACx?usp=sharing) is the easiest option.

The rest of this section covers running locally. A CUDA environment is recommended, since CPU inference is slow.  
The weight files are uploaded [here](https://drive.google.com/drive/folders/1N-Np2NmNLtLGS9gjnreBkYdTxrH1EHFw?usp=sharing):
youtube_version.ckpt is the model that played against Takumi, and latest.ckpt was trained for a few additional days.  
Set the path of the downloaded file inside engine/***_player.sh.
By default, the scripts assume the weights are placed under work_dirs.

### Docker

If you have nvidia-docker set up with CUDA 10.2 or later, you can enter the environment with:
```bash
$ make docker-start-interactive
```

### Ubuntu 18.04

With CUDA 10.2 and Anaconda installed, you can create a virtual environment with:
```bash
$ conda env create -f env_name.yml
$ conda activate bert-mcts-youtube
$ python setup.py develop
```

### Windows 10

Untested

## Testing the shogi engines

Engine scripts are provided in the engine directory and can be launched from GUIs such as ShogiGUI.
- policy_player.sh plays moves using only the policy output by BERT (weak)
- mcts_player.sh runs an MCTS search guided by BERT's outputs

## Training

Training uses a collection of even (balanced) opening positions and GCT's self-play game records.  
The model is first pretrained as a Masked Language Model, then trained as a Policy Value Network.  
That said, shogi has plenty of high-quality supervised data, so pretraining probably adds little.
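The Masked Language Model pretraining uses BERT-style 80/10/10 masking (implemented in src/data/mlm.py). A minimal NumPy sketch of that scheme, with a stand-in piece vocabulary:

```python
import numpy as np

# Stand-ins for pieces_list and the mask id from src/data/mlm.py.
PIECES = np.arange(1, 15)
MASK_TOKEN_ID = 32  # an id not assigned to any piece


def mask_sequence(seq, rng):
    """BERT-style masking: 15% of tokens become prediction targets;
    of those, 80% -> mask token, 10% -> random piece, 10% unchanged."""
    inputs = seq.copy()
    labels = seq.copy()
    targets = rng.random(seq.shape) < 0.15
    labels[~targets] = -100                       # ignored by the loss
    replaced = (rng.random(seq.shape) < 0.8) & targets
    inputs[replaced] = MASK_TOKEN_ID              # 80% become the mask token
    randomized = (rng.random(seq.shape) < 0.5) & targets & ~replaced
    inputs[randomized] = rng.choice(PIECES, seq.shape)[randomized]
    return inputs, labels                         # remaining 10% left as-is


rng = np.random.default_rng(0)
seq = rng.integers(1, 15, size=95)                # 81 squares + 2x7 hand slots
inputs, labels = mask_sequence(seq, rng)
```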

### Preparing the data

Download the even-position collection:
```bash
$ cd data
$ git clone https://github.com/tttak/ShogiGokakuKyokumen.git
```

GCT self-play records:
```bash
$ cd data
$ mkdir hcpe
```

GCT's self-play records are published by the developer, Kanoh, at [this link](https://drive.google.com/drive/folders/14FaqqIHRctTQIY6hScCFXWQQZ_pSU3-F).
Download a few of the files named selfplay-*** into data/hcpe.  
The files are large, so even a single one provides plenty of data.

Once these are ready, build the dataset with the following command:

```bash
$ python tools/make_dataset.py
```

### Training the Masked Language Model

```bash
$ python tools/train.py configs/mlm_base.yaml
```

### Converting the weight file

Convert the Masked Language Model checkpoint to transformers format.
This makes the transfer-learning code somewhat easier to write.

```bash
$ python tools/pl_to_transformers.py work_dirs/mlm_base/version_0/checkpoints/last.ckpt
```
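The point of the conversion is to end up with a directory that `from_pretrained` can load. A minimal sketch of that round trip with the standard transformers API (a tiny illustrative config, not the repository's actual script):

```python
import tempfile

from transformers import BertConfig, BertForMaskedLM

# Tiny config so the sketch runs quickly; the repository's real config is in
# src/model/bert.py (vocab_size = len(pieces_list) + 4, hidden_size = 768, ...).
config = BertConfig(vocab_size=36, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64,
                    max_position_embeddings=95)
model = BertForMaskedLM(config)

out_dir = tempfile.mkdtemp()
model.save_pretrained(out_dir)                   # writes config.json + weights
reloaded = BertForMaskedLM.from_pretrained(out_dir)
```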

### Training the Policy Value Model

Finally, train the Policy Value model using the above:

```bash
$ python tools/train.py configs/policy_value_base.yaml
```

================================================
FILE: configs/.gitkeep
================================================


================================================
FILE: configs/mlm_base.yaml
================================================
model_type: 'MLM'

seed: 42
dataset_dir: './data/dataset/gokaku_100'
model_dir:

train_loader:
  batch_size: 64
  shuffle: True
  num_workers: 8
  pin_memory: False
  drop_last: True

val_loader:
  batch_size: 64
  shuffle: False
  num_workers: 8
  pin_memory: False
  drop_last: False

train_params:
  max_epochs: 5
  # interval in steps between validation (and checkpoint) runs
  val_check_interval: 3000
  # adjust as needed for your environment
  gpus: [0]

================================================
FILE: configs/policy_value_base.yaml
================================================
model_type: 'PolicyValue'

seed: 42
dataset_dir: './data/dataset/selfplay'
model_dir: './work_dirs/mlm_base/version_0/checkpoints'

train_loader:
  batch_size: 128
  shuffle: True
  num_workers: 0
  pin_memory: True
  drop_last: True

val_loader:
  batch_size: 128
  shuffle: False
  num_workers: 0
  pin_memory: True
  drop_last: False

train_params:
  max_epochs: 1
  # interval in steps between validation (and checkpoint) runs
  val_check_interval: 30000
  limit_val_batches: 0.1
  # adjust as needed for your environment
  gpus: [0]


================================================
FILE: docker/Dockerfile
================================================
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
ENV DEBIAN_FRONTEND=noninteractive
ADD requirements.txt /tmp
RUN pip install -r /tmp/requirements.txt

ADD docker/entrypoint.sh /tmp
ENTRYPOINT ["bash", "/tmp/entrypoint.sh"]


================================================
FILE: docker/entrypoint.sh
================================================
python setup.py develop
exec "$@"


================================================
FILE: engine/mcts_player.sh
================================================
#!/bin/sh
python -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt


================================================
FILE: engine/policy_player.sh
================================================
#!/bin/sh
python -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt


================================================
FILE: env_name.yml
================================================
name: bert-mcts-youtube
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - alabaster=0.7.12=py37_0
  - anaconda=2020.11=py37_0
  - anaconda-client=1.7.2=py37_0
  - anaconda-project=0.8.4=py_0
  - argh=0.26.2=py37_0
  - argon2-cffi=20.1.0=py37h7b6447c_1
  - asn1crypto=1.4.0=py_0
  - astroid=2.4.2=py37_0
  - astropy=4.0.2=py37h7b6447c_0
  - async_generator=1.10=py37h28b3542_0
  - atomicwrites=1.4.0=py_0
  - attrs=20.3.0=pyhd3eb1b0_0
  - autopep8=1.5.4=py_0
  - babel=2.8.1=pyhd3eb1b0_0
  - backcall=0.2.0=py_0
  - backports=1.0=py_2
  - backports.shutil_get_terminal_size=1.0.0=py37_2
  - beautifulsoup4=4.9.3=pyhb0f4dca_0
  - bitarray=1.6.1=py37h27cfd23_0
  - bkcharts=0.2=py37_0
  - blas=1.0=mkl
  - bleach=3.2.1=py_0
  - blosc=1.20.1=hd408876_0
  - bokeh=2.2.3=py37_0
  - boto=2.49.0=py37_0
  - bottleneck=1.3.2=py37heb32a55_1
  - brotlipy=0.7.0=py37h7b6447c_1000
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2020.10.14=0
  - cairo=1.14.12=h8948797_3
  - certifi=2020.6.20=pyhd3eb1b0_3
  - cffi=1.14.3=py37he30daa8_0
  - chardet=3.0.4=py37_1003
  - click=7.1.2=py_0
  - cloudpickle=1.6.0=py_0
  - clyent=1.2.2=py37_1
  - colorama=0.4.4=py_0
  - contextlib2=0.6.0.post1=py_0
  - cryptography=3.1.1=py37h1ba5d50_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - curl=7.71.1=hbc83047_1
  - cycler=0.10.0=py37_0
  - cython=0.29.21=py37he6710b0_0
  - cytoolz=0.11.0=py37h7b6447c_0
  - dask=2.30.0=py_0
  - dask-core=2.30.0=py_0
  - dbus=1.13.18=hb2f20db_0
  - decorator=4.4.2=py_0
  - defusedxml=0.6.0=py_0
  - diff-match-patch=20200713=py_0
  - distributed=2.30.1=py37h06a4308_0
  - docutils=0.16=py37_1
  - entrypoints=0.3=py37_0
  - et_xmlfile=1.0.1=py_1001
  - expat=2.2.10=he6710b0_2
  - fastcache=1.1.0=py37h7b6447c_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.0.12=py_0
  - flake8=3.8.4=py_0
  - flask=1.1.2=py_0
  - fontconfig=2.13.0=h9420a91_0
  - freetype=2.10.4=h5ab3b9f_0
  - fribidi=1.0.10=h7b6447c_0
  - fsspec=0.8.3=py_0
  - future=0.18.2=py37_1
  - get_terminal_size=1.0.0=haa9412d_0
  - gevent=20.9.0=py37h7b6447c_0
  - glib=2.66.1=h92f7085_0
  - glob2=0.7=py_0
  - gmp=6.1.2=h6c8ec71_1
  - gmpy2=2.0.8=py37h10f8cd9_2
  - gnutls=3.6.15=he1e5248_0
  - graphite2=1.3.14=h23475e2_0
  - greenlet=0.4.17=py37h7b6447c_0
  - gst-plugins-base=1.14.0=hbbd80ab_1
  - gstreamer=1.14.0=hb31296c_0
  - h5py=2.10.0=py37h7918eee_0
  - harfbuzz=2.4.0=hca77d97_1
  - hdf5=1.10.4=hb1b8bf9_0
  - heapdict=1.0.1=py_0
  - html5lib=1.1=py_0
  - icu=58.2=he6710b0_3
  - idna=2.10=py_0
  - imageio=2.9.0=py_0
  - imagesize=1.2.0=py_0
  - importlib-metadata=2.0.0=py_1
  - importlib_metadata=2.0.0=1
  - iniconfig=1.1.1=py_0
  - intel-openmp=2020.2=254
  - intervaltree=3.1.0=py_0
  - ipykernel=5.3.4=py37h5ca1d4c_0
  - ipython=7.19.0=py37hb070fc8_0
  - ipython_genutils=0.2.0=py37_0
  - ipywidgets=7.5.1=py_1
  - isort=5.6.4=py_0
  - itsdangerous=1.1.0=py37_0
  - jbig=2.1=hdba287a_0
  - jdcal=1.4.1=py_0
  - jedi=0.17.1=py37_0
  - jeepney=0.5.0=pyhd3eb1b0_0
  - jinja2=2.11.2=py_0
  - joblib=0.17.0=py_0
  - jpeg=9b=h024ee3a_2
  - json5=0.9.5=py_0
  - jsonschema=3.2.0=py_2
  - jupyter=1.0.0=py37_7
  - jupyter_client=6.1.7=py_0
  - jupyter_console=6.2.0=py_0
  - jupyter_core=4.6.3=py37_0
  - jupyterlab=2.2.6=py_0
  - jupyterlab_pygments=0.1.2=py_0
  - jupyterlab_server=1.2.0=py_0
  - keyring=21.4.0=py37_1
  - kiwisolver=1.3.0=py37h2531618_0
  - krb5=1.18.2=h173b8e3_0
  - lame=3.100=h7b6447c_0
  - lazy-object-proxy=1.4.3=py37h7b6447c_0
  - lcms2=2.11=h396b838_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libarchive=3.4.2=h62408e4_0
  - libcurl=7.71.1=h20c2e04_1
  - libedit=3.1.20191231=h14c3975_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libiconv=1.15=h63c8f33_5
  - libidn2=2.3.0=h27cfd23_0
  - liblief=0.10.1=he6710b0_0
  - libllvm10=10.0.1=hbcb73fb_5
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.18=h7b6447c_0
  - libspatialindex=1.9.3=he6710b0_0
  - libssh2=1.9.0=h1ba5d50_1
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.1.0=h2733197_1
  - libtool=2.4.6=h7b6447c_1005
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.10=hb55368b_3
  - libxslt=1.1.34=hc22bd24_0
  - llvmlite=0.34.0=py37h269e1b5_4
  - locket=0.2.0=py37_1
  - lxml=4.6.1=py37hefd8a0e_0
  - lz4-c=1.9.2=heb0550a_3
  - lzo=2.10=h7b6447c_2
  - markupsafe=1.1.1=py37h14c3975_1
  - matplotlib=3.3.2=0
  - matplotlib-base=3.3.2=py37h817c723_0
  - mccabe=0.6.1=py37_1
  - mistune=0.8.4=py37h14c3975_1001
  - mkl=2020.2=256
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.2.0=py37h23d657b_0
  - mkl_random=1.1.1=py37h0573a6f_0
  - mock=4.0.2=py_0
  - more-itertools=8.6.0=pyhd3eb1b0_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.1.0=py37_0
  - msgpack-python=1.0.0=py37hfd86e86_1
  - multipledispatch=0.6.0=py37_0
  - nbclient=0.5.1=py_0
  - nbconvert=6.0.7=py37_0
  - nbformat=5.0.8=py_0
  - ncurses=6.2=he6710b0_1
  - nest-asyncio=1.4.2=pyhd3eb1b0_0
  - nettle=3.7.2=hbbd107a_1
  - networkx=2.5=py_0
  - ninja=1.10.2=py37hff7bd54_0
  - nltk=3.5=py_0
  - nose=1.3.7=py37_1004
  - notebook=6.1.4=py37_0
  - numba=0.51.2=py37h04863e7_1
  - numexpr=2.7.1=py37h423224d_0
  - numpy=1.19.2=py37h54aff64_0
  - numpy-base=1.19.2=py37hfa32c7d_0
  - numpydoc=1.1.0=pyhd3eb1b0_1
  - olefile=0.46=py37_0
  - openh264=2.1.0=hd408876_0
  - openpyxl=3.0.5=py_0
  - openssl=1.1.1h=h7b6447c_0
  - packaging=20.4=py_0
  - pandas=1.1.3=py37he6710b0_0
  - pandoc=2.11=hb0f4dca_0
  - pandocfilters=1.4.3=py37h06a4308_1
  - pango=1.45.3=hd140c19_0
  - parso=0.7.0=py_0
  - partd=1.1.0=py_0
  - patchelf=0.12=he6710b0_0
  - path=15.0.0=py37_0
  - path.py=12.5.0=0
  - pathlib2=2.3.5=py37_1
  - pathtools=0.1.2=py_1
  - patsy=0.5.1=py37_0
  - pcre=8.44=he6710b0_0
  - pep8=1.7.1=py37_0
  - pexpect=4.8.0=py37_1
  - pickleshare=0.7.5=py37_1001
  - pillow=8.0.1=py37he98fc37_0
  - pip=20.2.4=py37h06a4308_0
  - pixman=0.40.0=h7b6447c_0
  - pkginfo=1.6.1=py37h06a4308_0
  - pluggy=0.13.1=py37_0
  - ply=3.11=py37_0
  - prometheus_client=0.8.0=py_0
  - prompt-toolkit=3.0.8=py_0
  - prompt_toolkit=3.0.8=0
  - psutil=5.7.2=py37h7b6447c_0
  - ptyprocess=0.6.0=py37_0
  - py=1.9.0=py_0
  - py-lief=0.10.1=py37h403a769_0
  - pycodestyle=2.6.0=py_0
  - pycosat=0.6.3=py37h7b6447c_0
  - pycparser=2.20=py_2
  - pycrypto=2.6.1=py37h7b6447c_10
  - pycurl=7.43.0.6=py37h1ba5d50_0
  - pydocstyle=5.1.1=py_0
  - pyflakes=2.2.0=py_0
  - pygments=2.7.2=pyhd3eb1b0_0
  - pylint=2.6.0=py37_0
  - pyodbc=4.0.30=py37he6710b0_0
  - pyopenssl=19.1.0=py_1
  - pyparsing=2.4.7=py_0
  - pyqt=5.9.2=py37h05f1152_2
  - pyrsistent=0.17.3=py37h7b6447c_0
  - pysocks=1.7.1=py37_1
  - pytables=3.6.1=py37h71ec239_0
  - pytest=6.1.1=py37_0
  - python=3.7.9=h7579374_0
  - python-dateutil=2.8.1=py_0
  - python-jsonrpc-server=0.4.0=py_0
  - python-language-server=0.35.1=py_0
  - python-libarchive-c=2.9=py_0
  - pytorch=1.8.1=py3.7_cuda10.2_cudnn7.6.5_0
  - pytz=2020.1=py_0
  - pywavelets=1.1.1=py37h7b6447c_2
  - pyxdg=0.27=pyhd3eb1b0_0
  - pyyaml=5.3.1=py37h7b6447c_1
  - pyzmq=19.0.2=py37he6710b0_1
  - qdarkstyle=2.8.1=py_0
  - qt=5.9.7=h5867ecd_1
  - qtawesome=1.0.1=py_0
  - qtconsole=4.7.7=py_0
  - qtpy=1.9.0=py_0
  - readline=8.0=h7b6447c_0
  - regex=2020.10.15=py37h7b6447c_0
  - requests=2.24.0=py_0
  - ripgrep=12.1.1=0
  - rope=0.18.0=py_0
  - rtree=0.9.4=py37_1
  - ruamel_yaml=0.15.87=py37h7b6447c_1
  - scikit-image=0.17.2=py37hdf5156a_0
  - scikit-learn=0.23.2=py37h0573a6f_0
  - scipy=1.5.2=py37h0b6359f_0
  - seaborn=0.11.0=py_0
  - secretstorage=3.1.2=py37_1
  - send2trash=1.5.0=py37_0
  - setuptools=50.3.1=py37h06a4308_1
  - simplegeneric=0.8.1=py37_2
  - singledispatch=3.4.0.3=py_1001
  - sip=4.19.8=py37hf484d3e_0
  - six=1.15.0=py37h06a4308_0
  - snowballstemmer=2.0.0=py_0
  - sortedcollections=1.2.1=py_0
  - sortedcontainers=2.2.2=py_0
  - soupsieve=2.0.1=py_0
  - sphinx=3.2.1=py_0
  - sphinxcontrib=1.0=py37_1
  - sphinxcontrib-applehelp=1.0.2=py_0
  - sphinxcontrib-devhelp=1.0.2=py_0
  - sphinxcontrib-htmlhelp=1.0.3=py_0
  - sphinxcontrib-jsmath=1.0.1=py_0
  - sphinxcontrib-qthelp=1.0.3=py_0
  - sphinxcontrib-serializinghtml=1.1.4=py_0
  - sphinxcontrib-websupport=1.2.4=py_0
  - spyder=4.1.5=py37_0
  - spyder-kernels=1.9.4=py37_0
  - sqlalchemy=1.3.20=py37h27cfd23_0
  - sqlite=3.33.0=h62c20be_0
  - statsmodels=0.12.0=py37h7b6447c_0
  - sympy=1.6.2=py37h06a4308_1
  - tbb=2020.3=hfd86e86_0
  - tblib=1.7.0=py_0
  - terminado=0.9.1=py37_0
  - testpath=0.4.4=py_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tifffile=2020.10.1=py37hdd07704_2
  - tk=8.6.10=hbc83047_0
  - toml=0.10.1=py_0
  - toolz=0.11.1=py_0
  - torchvision=0.9.1=py37_cu102
  - tornado=6.0.4=py37h7b6447c_1
  - tqdm=4.50.2=py_0
  - traitlets=5.0.5=py_0
  - typed-ast=1.4.1=py37h7b6447c_0
  - typing_extensions=3.7.4.3=py_0
  - ujson=4.0.1=py37he6710b0_0
  - unicodecsv=0.14.1=py37_0
  - unixodbc=2.3.9=h7b6447c_0
  - urllib3=1.25.11=py_0
  - watchdog=0.10.3=py37_0
  - wcwidth=0.2.5=py_0
  - webencodings=0.5.1=py37_1
  - werkzeug=1.0.1=py_0
  - wheel=0.35.1=py_0
  - widgetsnbextension=3.5.1=py37_0
  - wrapt=1.11.2=py37h7b6447c_0
  - wurlitzer=2.0.1=py37_0
  - xlrd=1.2.0=py37_0
  - xlsxwriter=1.3.7=py_0
  - xlwt=1.3.0=py37_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - yapf=0.30.0=py_0
  - zeromq=4.3.3=he6710b0_3
  - zict=2.0.0=py_0
  - zipp=3.4.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zope=1.0=py37_1
  - zope.event=4.5.0=py37_0
  - zope.interface=5.1.2=py37h7b6447c_0
  - zstd=1.4.5=h9ceee32_0
  - pip:
    - absl-py==0.12.0
    - aiohttp==3.7.4.post0
    - async-timeout==3.0.1
    - cachetools==4.2.1
    - cshogi==0.2.4
    - google-auth==1.28.1
    - google-auth-oauthlib==0.4.4
    - grpcio==1.37.0
    - markdown==3.3.4
    - multidict==5.1.0
    - oauthlib==3.1.0
    - omegaconf==2.0.6
    - protobuf==3.15.8
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pytorch-lightning==1.2.7
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - sacremoses==0.0.44
    - tensorboard==2.4.1
    - tensorboard-plugin-wit==1.8.0
    - tokenizers==0.10.2
    - torchmetrics==0.2.0
    - transformers==4.5.1
    - yarl==1.6.3
prefix: /home/charmer/.pyenv/versions/anaconda3-2020.07/envs/bert-mcts


================================================
FILE: requirements.txt
================================================
--find-links https://download.pytorch.org/whl/torch_stable.html
torch>=1.6
pytorch-lightning==1.2.7
transformers>=4.5
cshogi
omegaconf

================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension

if __name__ == '__main__':
    setup(
        name='bert-mcts',
        description='BERT-driven Shogi AI',
        author='Hiroki Taniai',
        author_email='charmer.popopo@gmail.com',
        packages=find_packages(include=['src', 'src.*']),  # include subpackages under src
        cmdclass={'build_ext': BuildExtension},
    )


================================================
FILE: src/data/__init__.py
================================================



================================================
FILE: src/data/mlm.py
================================================
import numpy as np
import torch
from torch.utils.data import Dataset

from src.utils.shogi import pieces_list


# Masked Language Model
class MLMDataset(Dataset):
    def __init__(self, data):
        self.seqs = data['seq']
        self.mask_token_id = 32  # 32 is an id not assigned to any piece
        assert self.mask_token_id not in np.array(pieces_list)

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        inputs = np.array(self.seqs[idx])
        labels = inputs.copy()

        # prediction targets (15% of tokens)
        masked_indices = np.random.random(labels.shape) < 0.15
        labels[~masked_indices] = -100

        # 80% of targets become the mask token
        indices_replaced = (np.random.random(labels.shape) < 0.8) & masked_indices
        inputs[indices_replaced] = self.mask_token_id

        # 10% are replaced with a random piece (half of the remaining 20%)
        indices_random = (np.random.random(labels.shape) < 0.5) & masked_indices & ~indices_replaced
        random_words = np.random.choice(pieces_list, labels.shape)
        inputs[indices_random] = random_words[indices_random]

        # the remaining 10% are left unchanged

        ret_dict = {'input_ids': torch.tensor(inputs, dtype=torch.long),
                    'labels': torch.tensor(labels, dtype=torch.long)}
        return ret_dict


================================================
FILE: src/data/policy_value.py
================================================
import torch
from torch.utils.data import Dataset


class PolicyValueDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        dt = self.data[idx]

        ret_dict = {'input_ids': torch.tensor(dt['seq'], dtype=torch.long),
                    'labels': torch.tensor(dt['label'], dtype=torch.long),
                    'values': torch.tensor(dt['value'], dtype=torch.float),
                    'result': torch.tensor(dt['result'], dtype=torch.long)}
        return ret_dict


================================================
FILE: src/features/__init__.py
================================================



================================================
FILE: src/features/common.py
================================================
import cshogi

from src.utils.shogi import reverse_piece_fn


def get_seq_from_board(board):
    bp, wp = board.pieces_in_hand
    if board.turn == cshogi.BLACK:
        seq = board.pieces + bp + wp
    else:
        # Reverse the square order and flip each piece's side; swap the two hands.
        seq = reverse_piece_fn(board.pieces[::-1]).tolist() + wp + bp
    return seq


================================================
FILE: src/features/policy_value.py
================================================
import cshogi
import numpy as np
from cshogi import move_drop_hand_piece, move_from, move_is_drop, move_is_promotion, move_to

from src.features.common import get_seq_from_board
from src.utils.shogi import (DOWN, DOWN_LEFT, DOWN_RIGHT, LEFT, MOVE_DIRECTION, MOVE_DIRECTION_PROMOTED, RIGHT, UP,
                             UP2_LEFT, UP2_RIGHT, UP_LEFT, UP_RIGHT)

board = cshogi.Board()


def get_move_label(move, color):
    if not move_is_drop(move):
        from_sq = move_from(move)
        to_sq = move_to(move)
        if color == cshogi.WHITE:
            to_sq = 80 - to_sq
            from_sq = 80 - from_sq

        # file: column (筋), rank: row (段)
        from_file, from_rank = divmod(from_sq, 9)
        to_file, to_rank = divmod(to_sq, 9)
        dir_file = to_file - from_file
        dir_rank = to_rank - from_rank
        if dir_rank < 0 and dir_file == 0:
            move_direction = UP
        elif dir_rank == -2 and dir_file == -1:
            move_direction = UP2_RIGHT
        elif dir_rank == -2 and dir_file == 1:
            move_direction = UP2_LEFT
        elif dir_rank < 0 and dir_file < 0:
            move_direction = UP_RIGHT
        elif dir_rank < 0 and dir_file > 0:
            move_direction = UP_LEFT
        elif dir_rank == 0 and dir_file < 0:
            move_direction = RIGHT
        elif dir_rank == 0 and dir_file > 0:
            move_direction = LEFT
        elif dir_rank > 0 and dir_file == 0:
            move_direction = DOWN
        elif dir_rank > 0 and dir_file < 0:
            move_direction = DOWN_RIGHT
        elif dir_rank > 0 and dir_file > 0:
            move_direction = DOWN_LEFT
        else:
            raise RuntimeError

        # promote
        if move_is_promotion(move):
            move_direction = MOVE_DIRECTION_PROMOTED[move_direction]

    else:
        # drop move (piece from hand)
        move_direction = len(MOVE_DIRECTION) + move_drop_hand_piece(move) - 1
        to_sq = move_to(move)
        if color == cshogi.WHITE:
            to_sq = 80 - to_sq

    # max label is 27 * 81 - 1 = 2186
    move_label = move_direction * 81 + to_sq
    return move_label


def get_result(result, color):
    # draw
    if result == 0:
        return 0.5
    # the side to move won
    elif ((result == 1) and (color == cshogi.BLACK)) or ((result == 2) and (color == cshogi.WHITE)):
        return 1
    else:
        return 0


def get_policy_value_label(hcpe):
    board.reset()
    board.set_hcp(hcpe['hcp'])

    seq = get_seq_from_board(board)
    label = get_move_label(hcpe['bestMove16'], board.turn)
    value = 1 / (1 + np.exp(-hcpe['eval'] * 0.0013226))
    result = get_result(hcpe['gameResult'], board.turn)

    return seq, label, value, result


def get_policy_value_label_from_moves(moves):
    n = np.random.randint(4, len(moves)-1)
    board.reset()
    for move in moves[:n]:
        board.push(move)
    seq = get_seq_from_board(board)
    label = get_move_label(moves[n], board.turn)
    value = 0.5
    result = 0.5
    return seq, label, value, result


def get_moves_from_lines(line):
    board.reset()
    moves = []
    for move_usi in line.split()[2:]:
        move = board.move_from_usi(move_usi)
        board.push(move)
        moves.append(move)
    return moves


================================================
FILE: src/model/__init__.py
================================================



================================================
FILE: src/model/bert.py
================================================
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM, BertModel

from src.utils.shogi import pieces_list

config = {
    'vocab_size': len(pieces_list) + 4,  # MASK_TOKEN_ID, MASK, CLS, SEP
    'hidden_size': 768,
    'num_hidden_layers': 12,
    'num_attention_heads': 12,
    'intermediate_size': 3072,  # rule of thumb: hidden_size * 4
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'attention_probs_dropout_prob': 0.1,
    'max_position_embeddings': 512,  # 95 (= 81 squares + 7 black hand pieces + 7 white hand pieces) would suffice
    'type_vocab_size': 1,  # no sentence pairs are used, so token_type_embeddings is effectively unused
    'initializer_range': 0.02,
}
config = BertConfig.from_dict(config)


class BertMLM(nn.Module):
    def __init__(self, model_dir=None):
        super().__init__()
        if model_dir is None:
            self.bert = BertForMaskedLM(config)
        else:
            self.bert = BertForMaskedLM.from_pretrained(model_dir)

    def forward(self, input_ids, labels):
        return self.bert(input_ids=input_ids, labels=labels)


class BertPolicyValue(nn.Module):
    def __init__(self, model_dir=None):
        super().__init__()
        if model_dir is None:
            self.bert = BertModel(config)
        else:
            self.bert = BertModel.from_pretrained(model_dir)

        self.policy_head = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.Tanh(),
            nn.Linear(768 * 2, 9 * 9 * 27)
        )

        self.value_head = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.Tanh(),
            nn.Linear(768 * 2, 1),
            nn.Sigmoid()
        )

        self.loss_policy_fn = nn.CrossEntropyLoss()
        self.loss_value_fn = nn.MSELoss()

    def forward(self, input_ids, labels=None):
        features = self.bert(input_ids=input_ids)['last_hidden_state']
        policy = self.policy_head(features).mean(axis=1)
        value = self.value_head(features).mean(axis=1).squeeze(1)
        if labels is None:
            return {'policy': policy, 'value': value}
        else:
            loss_policy = self.loss_policy_fn(policy, labels['labels'])
            loss_value = self.loss_value_fn(value, labels['values'])
            loss = loss_policy + loss_value
            return {'loss_policy': loss_policy, 'loss_value': loss_value, 'loss': loss}


================================================
FILE: src/pl_modules/__init__.py
================================================
from .mlm import MLMModule, MLMDataModule
from .policy_value import PolicyValueModule, PolicyValueDataModule


def get_pl_modules(cfg):
    if cfg.model_type == 'MLM':
        return MLMModule(cfg), MLMDataModule(cfg)
    elif cfg.model_type == 'PolicyValue':
        return PolicyValueModule(cfg), PolicyValueDataModule(cfg)
    else:
        raise NotImplementedError


================================================
FILE: src/pl_modules/mlm.py
================================================
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AdamW

from src.data.mlm import MLMDataset
from src.model.bert import BertMLM


class MLMModule(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.model = BertMLM(hparams['model_dir'])

    def forward(self, batch):
        input_ids = batch['input_ids']
        labels = batch['labels']
        return self.model(input_ids=input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs[0]
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs[0].detach().cpu().numpy()
        return {'loss': loss}

    def validation_epoch_end(self, outputs):
        val_loss = np.mean([out['loss'] for out in outputs])
        self.log('steps', self.global_step)
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=5e-5)


class MLMDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def setup(self, stage=None):
        dataset_dir = Path(self.cfg.dataset_dir)
        train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)
        valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)
        self.train_dataset = MLMDataset(train_data)
        self.val_dataset = MLMDataset(valid_data)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, **self.cfg.train_loader)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, **self.cfg.val_loader)


================================================
FILE: src/pl_modules/policy_value.py
================================================
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AdamW

from src.data.policy_value import PolicyValueDataset
from src.model.bert import BertPolicyValue


class PolicyValueModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.model = BertPolicyValue(hparams['model_dir'])

    def forward(self, input_ids, labels=None):
        output = self.model(input_ids, labels)
        return output

    def training_step(self, batch, batch_idx):
        input_ids = batch.pop('input_ids')
        output = self(input_ids, batch)
        return {'loss': output['loss']}

    def validation_step(self, batch, batch_idx):
        input_ids = batch.pop('input_ids')
        output = self(input_ids, batch)
        for k, v in output.items():
            output[k] = v.detach().cpu().numpy()
        return output

    def validation_epoch_end(self, outputs):
        val_loss = np.mean([out['loss'] for out in outputs])
        val_loss_policy = np.mean([out['loss_policy'] for out in outputs])
        val_loss_value = np.mean([out['loss_value'] for out in outputs])
        self.log('val_loss', val_loss)
        self.log('val_loss_policy', val_loss_policy)
        self.log('val_loss_value', val_loss_value)

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=5e-5)


class PolicyValueDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def setup(self, stage=None):
        dataset_dir = Path(self.cfg.dataset_dir)
        train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)
        print('Loaded train data')
        valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)
        print('Loaded val data')
        self.train_dataset = PolicyValueDataset(train_data)
        self.val_dataset = PolicyValueDataset(valid_data)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, **self.cfg.train_loader)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, **self.cfg.val_loader)
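
training_step and validation_step pop input_ids out of the batch and hand the remaining entries to the model as its labels dict. A small sketch of that dict-splitting pattern with placeholder values (the key names follow the dataset fields used elsewhere in this repo):

```python
# Split a batch dict the way training_step does: remove 'input_ids'
# and treat everything left behind as the labels dict.
batch = {'input_ids': [1, 2, 3], 'label': 42, 'value': 0.7, 'result': 1.0}
input_ids = batch.pop('input_ids')

print(input_ids)      # [1, 2, 3]
print(sorted(batch))  # ['label', 'result', 'value']
```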



================================================
FILE: src/player/__init__.py
================================================



================================================
FILE: src/player/base_player.py
================================================
import cshogi


class BasePlayer:
    def __init__(self):
        self.board = cshogi.Board()

    def usi(self):
        print('id name bert_player')
        print('usiok')

    def usinewgame(self):
        pass

    def setoption(self, option):
        pass

    def isready(self):
        pass

    def position(self, moves):
        if moves[0] == 'startpos':
            self.board.reset()
            for move in moves[2:]:
                self.board.push_usi(move)
        elif moves[0] == 'sfen':
            self.board.set_sfen(' '.join(moves[1:]))
        # for debug
        print(self.board.sfen())

    def go(self):
        pass

    def quit(self):
        pass
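
BasePlayer.position receives the tokens after the `position` keyword, either `startpos [moves ...]` or `sfen <sfen string>`. A board-free sketch of that branching, with strings recorded in place of the cshogi.Board calls (this stub is an illustration, not the real class):

```python
# Parse USI 'position' arguments the way BasePlayer.position does,
# recording the actions instead of mutating a real cshogi.Board.
def parse_position(moves):
    actions = []
    if moves[0] == 'startpos':
        actions.append('reset')
        # moves[1] is the literal word 'moves'; actual moves start at index 2
        for move in moves[2:]:
            actions.append(f'push {move}')
    elif moves[0] == 'sfen':
        actions.append('set_sfen ' + ' '.join(moves[1:]))
    return actions

print(parse_position(['startpos', 'moves', '7g7f', '3c3d']))
# ['reset', 'push 7g7f', 'push 3c3d']
```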


================================================
FILE: src/player/mcts_player.py
================================================
import time
from argparse import ArgumentParser
from pathlib import Path

import cshogi
import numpy as np
import torch

from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.pl_modules.policy_value import PolicyValueModule
from src.player.base_player import BasePlayer
from src.player.usi import usi
from src.uct.uct_node import NOT_EXPANDED, NodeHash, UCT_HASH_SIZE, UctNode
from src.utils.misc import boltzmann


class MCTSPlayer(BasePlayer):
    def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign_threshold=0.01, c_puct=1):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.playout_halt = playout_halt
        self.temperature = temperature
        self.resign_threshold = resign_threshold
        self.c_puct = c_puct

        self.model = None
        self.node_hash = NodeHash()
        self.uct_nodes = [UctNode() for _ in range(UCT_HASH_SIZE)]
        self.current_n_idx = None
        self.playout_count = 0

    def usi(self):
        print('id name BERT-MCTS')
        print('usiok')

    def isready(self):
        if self.model is None:
            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
            self.model.cuda()
            self.model.eval()
        self.node_hash.initialize()
        print('readyok')

    def go(self):
        if self.board.is_game_over():
            print('bestmove resign')
            return

        # Clear search statistics
        self.playout_count = 0

        # Delete stale hash entries
        self.node_hash.delete_old_hash(self.board, self.uct_nodes)

        # Search start time
        begin_time = time.time()

        # Expand the root node
        self.current_n_idx = self.expand_node()

        # If there is only one candidate move, play it immediately
        current_node = self.uct_nodes[self.current_n_idx]
        child_moves = current_node.child_moves
        child_num = len(child_moves)
        if child_num == 1:
            print('bestmove', cshogi.move_to_usi(child_moves[0]))
            return

        def get_bestmove_and_print_info():
            # Elapsed search time
            finish_time = time.time() - begin_time

            if self.board.move_number < 5:
                selected_index = np.random.choice(np.arange(child_num), p=current_node.policy)
            else:
                # Select the most-visited move
                selected_index = np.argmax(current_node.child_moves_count)

            # If the selected child has zero visits, fall back to its policy prior
            if current_node.child_moves_count[selected_index] == 0:
                best_wp = current_node.policy[selected_index]
            # Otherwise use its mean win rate
            else:
                best_wp = current_node.child_value_sum[selected_index] / current_node.child_moves_count[selected_index]

            # Resign if the win rate is below the threshold
            if best_wp < self.resign_threshold:
                bestmove = 'resign'
            else:
                bestmove = cshogi.move_to_usi(child_moves[selected_index])

            # Convert the value to a centipawn-style score
            if best_wp == 1:
                cp = 30000
            elif best_wp == 0:
                cp = -30000
            else:
                cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)

            nps = int(current_node.move_count / finish_time)
            time_secs = int(finish_time * 1000)
            nodes = current_node.move_count
            hashfull = self.node_hash.get_usage_rate() * 100
            print(f'info score cp {cp} hashfull {hashfull:.2f} time {time_secs} nodes {nodes} nps {nps} pv {bestmove}')
            return bestmove

        # Repeat playouts until the playout count reaches the limit
        # or the search is cut off early
        while self.playout_count < self.playout_halt:
            self.playout_count += 1
            self.uct_search(self.current_n_idx)
            # Print search progress every 10 playouts
            if (self.playout_count + 1) % 10 == 0:
                get_bestmove_and_print_info()
            # Check whether to stop the search early
            if self.interruption_check() or not self.node_hash.enough_size:
                break

        bestmove = get_bestmove_and_print_info()
        print('bestmove', bestmove)

    def expand_node(self):
        current_hash = self.board.zobrist_hash()
        current_turn = self.board.turn
        current_move_number = self.board.move_number
        n_idx = self.node_hash.find_same_hash_index(current_hash, current_turn, current_move_number)

        # If a matching node (transposition) is found, return its index
        if n_idx != UCT_HASH_SIZE:
            return n_idx

        # Otherwise find an empty slot
        n_idx = self.node_hash.search_empty_index(current_hash, current_turn, current_move_number)

        # Initialize the current node
        current_node = self.uct_nodes[n_idx]
        current_node.reset()

        # Expand the candidate moves
        current_node.child_moves = [move for move in self.board.legal_moves]
        child_num = len(current_node.child_moves)
        current_node.child_n_indices = [NOT_EXPANDED for _ in range(child_num)]
        current_node.child_moves_count = np.zeros(child_num, dtype=np.int32)
        current_node.child_value_sum = np.zeros(child_num, dtype=np.float32)

        # Evaluate the node
        self.eval_node(n_idx)

        return n_idx

    def eval_node(self, n_idx):
        current_node = self.uct_nodes[n_idx]
        child_moves = current_node.child_moves
        child_num = len(current_node.child_moves)
        if child_num == 0:
            # No legal moves means a loss
            current_node.value = 0
            current_node.evaled = True
        else:
            # Policy and value for the current position
            seq = get_seq_from_board(self.board)
            input_ids = torch.tensor([seq], dtype=torch.long).cuda()
            with torch.no_grad():
                output = self.model(input_ids)
                value = output['value'].detach().cpu().numpy()[0]
                policy_logits = output['policy'].detach().cpu().numpy()[0]

            # Filter down to the legal moves
            legal_move_labels = []
            for move in child_moves:
                legal_move_labels.append(get_move_label(move, self.board.turn))

            # Boltzmann
            policy = boltzmann(policy_logits[legal_move_labels], self.temperature)

            # Update the node
            current_node.policy = policy
            current_node.value = value
            current_node.evaled = True

    def uct_search(self, n_idx):
        current_node = self.uct_nodes[n_idx]
        child_moves = current_node.child_moves
        child_n_indices = current_node.child_n_indices
        child_num = len(child_moves)

        # A mate here is a loss, i.e. a win from the parent node's perspective
        if child_num == 0:
            return 1

        # Move with the highest UCB value (PUCT algorithm)
        next_c_idx = self.select_max_ucb_child(n_idx)
        next_move = child_moves[next_c_idx]
        next_n_idx = child_n_indices[next_c_idx]

        # Play the selected move
        self.board.push(next_move)

        if next_n_idx == NOT_EXPANDED:
            # The selected child is not yet expanded, so expand it
            # (expand_node also evaluates the new node)
            next_n_idx = self.expand_node()
            child_n_indices[next_c_idx] = next_n_idx
            child_node = self.uct_nodes[next_n_idx]
            result = 1 - child_node.value
        else:
            # Already expanded, so search one ply deeper
            result = self.uct_search(next_n_idx)

        # Backup: propagate the search result
        current_node.move_count += 1
        current_node.child_value_sum[next_c_idx] += result
        current_node.child_moves_count[next_c_idx] += 1
        # Undo the move
        self.board.pop(next_move)

        return 1 - result

    def select_max_ucb_child(self, c_idx):
        current_node = self.uct_nodes[c_idx]
        child_wins_count = current_node.child_value_sum
        child_moves_count = current_node.child_moves_count
        child_num = len(child_moves_count)

        # Where child_moves_count is 0, fill q with 0.5
        q = np.divide(child_wins_count, child_moves_count, out=np.repeat(0.5, child_num), where=child_moves_count != 0)
        u = np.sqrt(current_node.move_count) / (1 + child_moves_count)
        ucb = q + self.c_puct * current_node.policy * u

        return np.argmax(ucb)

    # Check whether to stop the search early
    def interruption_check(self):
        child_move_count = self.uct_nodes[self.current_n_idx].child_moves_count
        rest = self.playout_halt - self.playout_count

        # Find the most-visited and second-most-visited moves
        second, first = child_move_count[np.argpartition(child_move_count, -2)[-2:]]

        # Stop when the runner-up cannot overtake the best move even if all remaining playouts go to it
        if first - second > rest:
            return True
        else:
            return False


def argparse():
    parser = ArgumentParser()
    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
    args, _ = parser.parse_known_args()
    print('Command Line Args:')
    print(args)
    return args


def main(args):
    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
    player = MCTSPlayer(ckpt_path)
    usi(player)


if __name__ == '__main__':
    args = argparse()
    main(args)
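
select_max_ucb_child scores each child with q + c_puct * policy * sqrt(N) / (1 + n), where q falls back to 0.5 for unvisited children. A pure-Python sketch of that PUCT rule (the numbers below are invented, not from a real search):

```python
import math

# PUCT child selection as in select_max_ucb_child, without numpy.
# Unvisited children get q = 0.5, so their prior term dominates.
def select_max_ucb(value_sum, visit_count, policy, parent_visits, c_puct=1.0):
    ucb = []
    for w, n, p in zip(value_sum, visit_count, policy):
        q = w / n if n > 0 else 0.5
        u = math.sqrt(parent_visits) / (1 + n)
        ucb.append(q + c_puct * p * u)
    return max(range(len(ucb)), key=ucb.__getitem__)

# Child 1 has a strong prior and no visits yet, so it is explored first.
idx = select_max_ucb([3.0, 0.0], [10, 0], [0.2, 0.8], parent_visits=10)
print(idx)  # 1
```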


================================================
FILE: src/player/policy_player.py
================================================
from argparse import ArgumentParser
from pathlib import Path
import numpy as np

import cshogi
import torch
import torch.nn.functional as F

from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.pl_modules.policy_value import PolicyValueModule
from src.player.base_player import BasePlayer
from src.player.usi import usi
from src.utils.misc import greedy


class PolicyPlayer(BasePlayer):
    def __init__(self, ckpt_path):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.model = None

    def usi(self):
        print('id name BERT-Policy')
        print('usiok')

    def isready(self):
        if self.model is None:
            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
            self.model.cuda()
            self.model.eval()
        print('readyok')

    def go(self):
        if self.board.is_game_over():
            print('bestmove resign')
            return

        seq = get_seq_from_board(self.board)
        input_ids = torch.tensor([seq], dtype=torch.long).cuda()

        with torch.no_grad():
            output = self.model(input_ids)
            policy = F.softmax(output['policy'], dim=1).detach().cpu().numpy()[0]

        # For every legal move
        legal_moves = []
        legal_policy = []
        for move in self.board.legal_moves:
            # Convert the move to a label
            label = get_move_label(move, self.board.turn)
            # Store the legal move and its probability (softmax output)
            legal_moves.append(move)
            legal_policy.append(policy[label])

        selected_index = greedy(legal_policy)
        bestmove = cshogi.move_to_usi(legal_moves[selected_index])
        best_wp = legal_policy[selected_index]
        # Convert the probability to a centipawn-style score
        if best_wp == 1:
            cp = 30000
        elif best_wp == 0:
            cp = -30000
        else:
            cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)

        print(f'info score cp {cp} pv {bestmove}')

        print('bestmove', bestmove)


def argparse():
    parser = ArgumentParser()
    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
    args, _ = parser.parse_known_args()
    print('Command Line Args:')
    print(args)
    return args


def main(args):
    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
    player = PolicyPlayer(ckpt_path)
    usi(player)


if __name__ == '__main__':
    args = argparse()
    main(args)
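
Both players map a win probability to a centipawn score via cp = -ln(1/p - 1) * 756.0864962951762 (an inverse sigmoid scaled to conventional engine evaluations), pinning p = 1 and p = 0 to ±30000. The conversion in isolation, written with clamping comparisons rather than the exact equality checks above:

```python
import math

# Win probability -> centipawn score, matching the conversion in go().
def winrate_to_cp(p):
    if p >= 1:
        return 30000
    if p <= 0:
        return -30000
    return int(-math.log(1 / p - 1) * 756.0864962951762)

print(winrate_to_cp(0.5))  # 0 (an even position)
print(winrate_to_cp(1.0))  # 30000
```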


================================================
FILE: src/player/usi.py
================================================
from src.player.base_player import BasePlayer


def usi(player: BasePlayer):
    while True:
        cmd_line = input()
        cmd = cmd_line.split(' ', 1)
        cmd = [c.rstrip() for c in cmd]

        if cmd[0] == 'usi':
            player.usi()
        elif cmd[0] == 'setoption':
            option = cmd[1].split(' ')
            player.setoption(option)
        elif cmd[0] == 'isready':
            player.isready()
        elif cmd[0] == 'usinewgame':
            player.usinewgame()
        elif cmd[0] == 'position':
            moves = cmd[1].split(' ')
            player.position(moves)
        elif cmd[0] == 'go':
            player.go()
        elif cmd[0] == 'quit':
            player.quit()
            break
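
The loop splits each command line at most once, so multi-word arguments (such as the move list after `position`) arrive intact in cmd[1]. The split in isolation:

```python
# USI commands are split into at most two parts, keyword + argument
# string, exactly as in the usi() loop above.
cmd_line = 'position startpos moves 7g7f 3c3d'
cmd = cmd_line.split(' ', 1)

print(cmd[0])  # 'position'
print(cmd[1])  # 'startpos moves 7g7f 3c3d'
print(cmd[1].split(' '))  # ['startpos', 'moves', '7g7f', '3c3d']
```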


================================================
FILE: src/uct/__init__.py
================================================



================================================
FILE: src/uct/uct_node.py
================================================
# Upper limit on the number of nodes
UCT_HASH_SIZE = 4096
# Index value for unexpanded nodes
NOT_EXPANDED = -1


# Compress a Zobrist hash into the UCT_HASH_SIZE range
def hash_to_index(zhash):
    return ((zhash & 0xffffffff) ^ ((zhash >> 32) & 0xffffffff)) & (UCT_HASH_SIZE - 1)


class NodeHashEntry:
    def __init__(self):
        self.hash = 0  # Zobrist hash value
        self.color = 0  # Side to move
        self.moves = 0  # Number of moves since the start of the game
        self.flag = False  # Whether this entry is in use

    def reset(self):
        self.hash = 0
        self.color = 0
        self.moves = 0
        self.flag = False


class NodeHash:
    def __init__(self):
        self.used = 0
        self.enough_size = True
        self.node_hash = None

    def initialize(self):
        self.used = 0
        self.enough_size = True

        if self.node_hash is None:
            self.node_hash = [NodeHashEntry() for _ in range(UCT_HASH_SIZE)]
        else:
            for i in range(UCT_HASH_SIZE):
                self.node_hash[i].reset()

    # Find and return an unused index
    def search_empty_index(self, zhash, color, moves):
        key = hash_to_index(zhash)
        i = key

        while True:
            if not self.node_hash[i].flag:
                self.node_hash[i].hash = zhash
                self.node_hash[i].color = color
                self.node_hash[i].moves = moves
                self.node_hash[i].flag = True
                self.used += 1
                if self.get_usage_rate() > 0.9:
                    self.enough_size = False
                return i
            i += 1
            if i >= UCT_HASH_SIZE:
                i = 0
            # Wrapped around to the start
            if i == key:
                return UCT_HASH_SIZE

    # Return the index corresponding to the given hash
    def find_same_hash_index(self, zhash, color, moves):
        key = hash_to_index(zhash)
        i = key

        while True:
            # Return i when the entry is in use and all attributes match
            if self.node_hash[i].flag and self.node_hash[i].hash == zhash and self.node_hash[i].color == color and \
                    self.node_hash[i].moves == moves:
                return i
            # Otherwise probe the next slot (linear probing, matching
            # search_empty_index); give up after a full cycle
            i += 1
            if i >= UCT_HASH_SIZE:
                i = 0
            if i == key:
                return UCT_HASH_SIZE

    # Keep nodes that are still in use
    def save_used_hash(self, board, uct_nodes, n_idx):
        self.node_hash[n_idx].flag = True
        self.used += 1

        current_node = uct_nodes[n_idx]
        child_n_indices = current_node.child_n_indices
        child_moves = current_node.child_moves
        child_num = len(child_moves)
        for i in range(child_num):
            if child_n_indices[i] != NOT_EXPANDED and not self.node_hash[child_n_indices[i]].flag:
                board.push(child_moves[i])
                self.save_used_hash(board, uct_nodes, child_n_indices[i])
                board.pop(child_moves[i])

    # Delete stale hash entries
    def delete_old_hash(self, board, uct_node):
        # Delete every node except those reachable from the current position
        n_idx = self.find_same_hash_index(board.zobrist_hash(), board.turn, board.move_number)

        self.used = 0
        # Clear only the in-use flags; save_used_hash below re-flags the kept
        # entries without losing their hash/color/moves attributes
        for i in range(UCT_HASH_SIZE):
            self.node_hash[i].flag = False

        if n_idx != UCT_HASH_SIZE:
            self.save_used_hash(board, uct_node, n_idx)

        self.enough_size = True

    def get_usage_rate(self):
        return self.used / UCT_HASH_SIZE


class UctNode:
    def __init__(self):
        self.evaled = False  # Whether this node has been evaluated
        self.move_count = 0  # Visit count of this node
        self.value = 0  # Value-network output for this node (predicted win rate)
        self.policy = None  # Normalized policy-network output (one entry per child)
        self.child_moves = None  # Moves leading to the child nodes
        self.child_n_indices = None  # Indices of the child nodes
        self.child_moves_count = None  # Visit counts of the child nodes (for UCB)
        self.child_value_sum = None  # Sum of the child node values (for UCB)

    def reset(self):
        self.evaled = False
        self.move_count = 0
        self.value = 0
        self.policy = None
        self.child_moves = None
        self.child_n_indices = None
        self.child_moves_count = None
        self.child_value_sum = None
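
hash_to_index folds the 64-bit Zobrist hash by XOR-ing its two 32-bit halves and masking with UCT_HASH_SIZE - 1, which works as a modulo because 4096 is a power of two. The fold on a made-up hash value:

```python
UCT_HASH_SIZE = 4096  # power of two, so (size - 1) works as a bit mask

def hash_to_index(zhash):
    # XOR-fold the 64-bit hash into 32 bits, then mask to the table size.
    return ((zhash & 0xffffffff) ^ ((zhash >> 32) & 0xffffffff)) & (UCT_HASH_SIZE - 1)

idx = hash_to_index(0x123456789ABCDEF0)
print(hex(idx))  # 0x888  (0x9ABCDEF0 ^ 0x12345678 = 0x88888888, masked to 12 bits)
```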


================================================
FILE: src/utils/__init__.py
================================================


================================================
FILE: src/utils/hcpe.py
================================================
import cshogi
import numpy as np
from tqdm import tqdm

from src.features.policy_value import get_policy_value_label


def get_data_from_hcpe(hcpes):
    data = []
    for hcpe in tqdm(hcpes):
        data.append(get_policy_value_label(hcpe))

    data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])
    return data


def load_hcpes(hcpe_paths):
    hcpe_list = []
    for hcpe_path in hcpe_paths:
        hcpe_list.append(np.fromfile(hcpe_path, dtype=cshogi.HuffmanCodedPosAndEval))
    hcpes = np.concatenate(hcpe_list)
    return hcpes


================================================
FILE: src/utils/misc.py
================================================
import numpy as np


def greedy(logits):
    return np.asarray(logits).argmax()


def boltzmann(logits, temperature):
    logits /= temperature
    logits -= logits.max()
    probabilities = np.exp(logits)
    probabilities /= probabilities.sum()
    return probabilities
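
boltzmann is a temperature-scaled softmax: divide the logits by the temperature, subtract the maximum for numerical stability, exponentiate, and normalize. A pure-Python sketch without numpy's in-place array operations:

```python
import math

# Temperature-scaled softmax, mirroring boltzmann() without numpy.
def boltzmann(logits, temperature):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

probs = boltzmann([2.0, 1.0, 0.0], temperature=1.0)
print(probs)       # highest logit gets the largest probability
print(sum(probs))  # ~1.0
```

Higher temperatures flatten the distribution, which is why the MCTS player applies it to the policy logits before search.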



================================================
FILE: src/utils/sfen.py
================================================
from pathlib import Path

import numpy as np

from src.features.policy_value import get_moves_from_lines, get_policy_value_label_from_moves


def get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):
    assert base_dir.name == 'ShogiGokakuKyokumen'
    sfen_paths = []
    # TODO: use a regular expression
    for sfen_path in sorted(list(base_dir.glob('*/*/*.sfen')) + list(base_dir.glob('*/*/*/*.sfen'))):
        # Balanced ('gokaku') game records with at most max_num_of_moves moves
        if '互角' in sfen_path.stem and (int(sfen_path.stem.split('手目')[0][-3:]) <= max_num_of_moves):
            sfen_paths.append(sfen_path)
    return sfen_paths


def get_data_from_sfen(sfens):
    # For sfen-format data, build training samples from intermediate positions of each record
    data = []
    for sfen in sfens:
        moves = get_moves_from_lines(sfen)
        # Use only records with at least 6 moves
        if len(moves) < 6:
            continue
        # e.g. a 30-move record yields 6 positions
        for _ in range(max(1, len(moves) // 5)):
            dt = get_policy_value_label_from_moves(moves)
            data.append(dt)

    data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])
    return data


def load_sfen(sfen_path):
    with open(sfen_path) as f:
        sfens = [l.rstrip() for l in f.readlines()]
    return sfens


def load_sfens(sfen_paths):
    sfens = []
    for kifu_path in sfen_paths:
        sfens.extend(load_sfen(kifu_path))
    return sfens
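
get_data_from_sfen samples roughly one position per five moves from each record, with a floor of one sample (records shorter than six moves are skipped beforehand). The sampling-count rule on its own:

```python
# Number of positions sampled from a record, as in get_data_from_sfen:
# about one per five moves, at least one.
def num_samples(num_moves):
    return max(1, num_moves // 5)

print(num_samples(30))  # 6
print(num_samples(6))   # 1
```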


================================================
FILE: src/utils/shogi.py
================================================
import numpy as np

# Move direction constants
MOVE_DIRECTION = [
    UP, UP_LEFT, UP_RIGHT, LEFT, RIGHT, DOWN, DOWN_LEFT, DOWN_RIGHT, UP2_LEFT, UP2_RIGHT,
    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,
    DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
] = range(20)

# Promotion conversion table
MOVE_DIRECTION_PROMOTED = [
    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,
    DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
]

# Number of labels representing moves
MOVE_DIRECTION_LABEL_NUM = len(MOVE_DIRECTION) + 7  # 7 kinds of droppable pieces in hand

# Build the mapping used by reverse_piece_fn to swap black and white pieces
bw_dict = {k: k + 16 for k in range(1, 15)}  # 1-14: own pieces, 17-30: opponent pieces
wb_dict = {v: k for k, v in bw_dict.items()}
pieces_dict = {**bw_dict, **wb_dict, 0: 0}  # 0 is an empty square
pieces_list = list(pieces_dict.keys())
reverse_piece_fn = np.vectorize(pieces_dict.get)
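
pieces_dict swaps cshogi's own-piece codes 1-14 with the opponent-piece codes 17-30 and leaves 0 (an empty square) fixed, so applying it twice is the identity. The mapping without np.vectorize:

```python
# Black <-> white piece-code swap, as built in src/utils/shogi.py.
bw_dict = {k: k + 16 for k in range(1, 15)}   # own piece -> opponent piece
wb_dict = {v: k for k, v in bw_dict.items()}  # opponent piece -> own piece
pieces_dict = {**bw_dict, **wb_dict, 0: 0}    # 0 = empty square, unchanged

print(pieces_dict[1], pieces_dict[17], pieces_dict[0])  # 17 1 0
```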


================================================
FILE: tools/download_and_build_lesserkai.sh
================================================
wget http://shogidokoro.starfree.jp/download/LesserkaiSrc.zip -P ./work_dirs
unzip ./work_dirs/LesserkaiSrc.zip -d ./work_dirs/
rm ./work_dirs/LesserkaiSrc.zip
cd ./work_dirs/LesserkaiSrc/Lesserkai
make

================================================
FILE: tools/make_dataset.py
================================================
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split

from src.utils.hcpe import get_data_from_hcpe, load_hcpes
from src.utils.sfen import get_data_from_sfen, get_gokaku_sfen_paths, load_sfens


def main():
    gokaku_dir = Path('./data/ShogiGokakuKyokumen/')
    hcpe_dir = Path('./data/hcpe/')
    dataset_base_dir = Path('./data/dataset/')
    dataset_base_dir.mkdir(exist_ok=True)

    for num_of_moves in [40, 100]:
        sfen_paths = get_gokaku_sfen_paths(gokaku_dir, num_of_moves)
        sfens = load_sfens(sfen_paths)
        data = get_data_from_sfen(sfens)
        train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)

        dataset_dir = dataset_base_dir / f'gokaku_{num_of_moves:03d}'
        dataset_dir.mkdir()

        np.save(dataset_dir / 'train.npy', train_data)
        np.save(dataset_dir / 'val.npy', valid_data)

    # Self-play game records
    hcpe_paths = sorted(hcpe_dir.glob('selfplay-*'))
    hcpes = load_hcpes(hcpe_paths)
    data = get_data_from_hcpe(hcpes)

    train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)

    dataset_dir = dataset_base_dir / 'selfplay'
    dataset_dir.mkdir()

    np.save(dataset_dir / 'train.npy', train_data)
    np.save(dataset_dir / 'val.npy', valid_data)


if __name__ == '__main__':
    main()


================================================
FILE: tools/pl_to_transformers.py
================================================
from argparse import ArgumentParser
from pathlib import Path

import torch

from src.model.bert import config


def argparse():
    parser = ArgumentParser(description='Convert pl checkpoint to transformers format')
    parser.add_argument('ckpt_path', type=str)
    args, _ = parser.parse_known_args()
    return args


def main(args):
    ckpt_path = Path(args.ckpt_path)
    ckpt_dir = ckpt_path.parent

    ckpt = torch.load(ckpt_path)
    state_dict = ckpt['state_dict']
    state_dict = {'.'.join(k.split('.')[2:]): v for k, v in state_dict.items()}

    # pytorch_model.bin and config.json must be in the same directory
    state_dict_path = ckpt_dir / 'pytorch_model.bin'
    torch.save(state_dict, state_dict_path)
    config.to_json_file(ckpt_dir / 'config.json')


if __name__ == '__main__':
    args = argparse()
    main(args)
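
The conversion strips the first two dotted components from every checkpoint key (the Lightning prefix such as `model.model.`) so the remaining keys match the transformers state dict. The remapping in isolation (the example key is illustrative):

```python
# Strip the leading '<module>.<attr>.' prefix from checkpoint keys,
# as pl_to_transformers.py does before saving pytorch_model.bin.
def strip_prefix(key):
    return '.'.join(key.split('.')[2:])

print(strip_prefix('model.model.bert.embeddings.word_embeddings.weight'))
# 'bert.embeddings.word_embeddings.weight'
```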


================================================
FILE: tools/test_engine.py
================================================
from argparse import ArgumentParser
from pathlib import Path

from cshogi import cli


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('engine_path')
    return parser.parse_args()


def main(args):
    benchmark = str(Path('./work_dirs/LesserkaiSrc/Lesserkai/Lesserkai').resolve())
    engine = str((Path(args.engine_path)).resolve())

    print('benchmark: ', benchmark)
    print('engine: ', engine)

    cli.main(benchmark, engine)


if __name__ == '__main__':
    args = parse_args()
    main(args)


================================================
FILE: tools/train.py
================================================
from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from src.pl_modules import get_pl_modules


def argparse():
    parser = ArgumentParser()
    parser.add_argument('cfg', type=str)
    parser.add_argument('--log_dir', type=str, default='./work_dirs')
    parser.add_argument('--ckpt_path', type=str, default='')
    parser = pl.Trainer.add_argparse_args(parser)
    args, _ = parser.parse_known_args()
    return args


def main(args):
    cfg = OmegaConf.load(args.cfg)
    pl.seed_everything(cfg.seed)

    # Reflect cfg.train_params onto args
    for k, v in cfg.train_params.items():
        setattr(args, k, v)

    # Disable default checkpoint callback
    args.checkpoint_callback = False
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.logger = pl_loggers.TensorBoardLogger(save_dir=args.log_dir, name=Path(args.cfg).stem, default_hp_metric=False)
    trainer.callbacks.append(ModelCheckpoint(filename='{step:07d}-{val_loss:.2f}', monitor='val_loss',
                                             save_top_k=1, save_last=True))

    model, data = get_pl_modules(cfg)
    if args.ckpt_path:
        model = model.load_from_checkpoint(args.ckpt_path)
    trainer.fit(model, datamodule=data)


if __name__ == '__main__':
    args = argparse()
    main(args)
SYMBOL INDEX (104 symbols across 20 files)

FILE: src/data/mlm.py
  class MLMDataset (line 9) | class MLMDataset(Dataset):
    method __init__ (line 10) | def __init__(self, data):
    method __len__ (line 15) | def __len__(self):
    method __getitem__ (line 18) | def __getitem__(self, idx):

FILE: src/data/policy_value.py
  class PolicyValueDataset (line 5) | class PolicyValueDataset(Dataset):
    method __init__ (line 6) | def __init__(self, data):
    method __len__ (line 9) | def __len__(self):
    method __getitem__ (line 12) | def __getitem__(self, idx):

FILE: src/features/common.py
  function get_seq_from_board (line 6) | def get_seq_from_board(board):

FILE: src/features/policy_value.py
  function get_move_label (line 12) | def get_move_label(move, color):
  function get_result (line 64) | def get_result(result, color):
  function get_policy_value_label (line 75) | def get_policy_value_label(hcpe):
  function get_policy_value_label_from_moves (line 87) | def get_policy_value_label_from_moves(moves):
  function get_moves_from_lines (line 99) | def get_moves_from_lines(line):

FILE: src/model/bert.py
  class BertMLM (line 22) | class BertMLM(nn.Module):
    method __init__ (line 23) | def __init__(self, model_dir=None):
    method forward (line 30) | def forward(self, input_ids, labels):
  class BertPolicyValue (line 34) | class BertPolicyValue(nn.Module):
    method __init__ (line 35) | def __init__(self, model_dir=None):
    method forward (line 58) | def forward(self, input_ids, labels=None):

FILE: src/pl_modules/__init__.py
  function get_pl_modules (line 5) | def get_pl_modules(cfg):

FILE: src/pl_modules/mlm.py
  class MLMModule (line 12) | class MLMModule(pl.LightningModule):
    method __init__ (line 14) | def __init__(self, hparams):
    method forward (line 19) | def forward(self, batch):
    method training_step (line 24) | def training_step(self, batch, batch_idx):
    method validation_step (line 30) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 35) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 40) | def configure_optimizers(self):
  class MLMDataModule (line 44) | class MLMDataModule(pl.LightningDataModule):
    method __init__ (line 45) | def __init__(self, cfg):
    method setup (line 49) | def setup(self, stage=None):
    method train_dataloader (line 56) | def train_dataloader(self):
    method val_dataloader (line 59) | def val_dataloader(self):

FILE: src/pl_modules/policy_value.py
  class PolicyValueModule (line 12) | class PolicyValueModule(pl.LightningModule):
    method __init__ (line 13) | def __init__(self, hparams):
    method forward (line 18) | def forward(self, input_ids, labels=None):
    method training_step (line 22) | def training_step(self, batch, batch_idx):
    method validation_step (line 27) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 34) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 42) | def configure_optimizers(self):
  class PolicyValueDataModule (line 46) | class PolicyValueDataModule(pl.LightningDataModule):
    method __init__ (line 47) | def __init__(self, cfg):
    method setup (line 51) | def setup(self, stage=None):
    method train_dataloader (line 60) | def train_dataloader(self):
    method val_dataloader (line 63) | def val_dataloader(self):

FILE: src/player/base_player.py
  class BasePlayer (line 4) | class BasePlayer:
    method __init__ (line 5) | def __init__(self):
    method usi (line 8) | def usi(self):
    method usinewgame (line 12) | def usinewgame(self):
    method setoption (line 15) | def setoption(self, option):
    method isready (line 18) | def isready(self):
    method position (line 21) | def position(self, moves):
    method go (line 31) | def go(self):
    method quit (line 34) | def quit(self):

FILE: src/player/mcts_player.py
  class MCTSPlayer (line 18) | class MCTSPlayer(BasePlayer):
    method __init__ (line 19) | def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign...
    method usi (line 33) | def usi(self):
    method isready (line 37) | def isready(self):
    method go (line 45) | def go(self):
    method expand_node (line 123) | def expand_node(self):
    method eval_node (line 152) | def eval_node(self, n_idx):
    method uct_search (line 182) | def uct_search(self, n_idx):
    method select_max_ucb_child (line 221) | def select_max_ucb_child(self, c_idx):
    method interruption_check (line 235) | def interruption_check(self):
  function argparse (line 249) | def argparse():
  function main (line 258) | def main(args):
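The index above shows the MCTS core of `src/player/mcts_player.py` (`expand_node`, `uct_search`, `select_max_ucb_child`). As a hedged sketch only: assuming an AlphaZero-style PUCT selection rule, which is the common choice for policy/value-network MCTS (the actual formula inside `select_max_ucb_child` is not visible in this extraction and may differ, and the function name and `C_PUCT` value here are illustrative):

```python
import numpy as np

# Hypothetical sketch of PUCT child selection as used in AlphaZero-style
# MCTS. The repo's select_max_ucb_child may use a different formula.
C_PUCT = 1.0  # exploration constant (assumed value)

def select_max_ucb_child(child_priors, child_values, child_visits):
    """Return the index of the child maximizing Q + U.

    child_priors: policy-network probabilities per child move
    child_values: sum of backed-up values per child
    child_visits: visit count per child
    """
    priors = np.asarray(child_priors, dtype=np.float64)
    values = np.asarray(child_values, dtype=np.float64)
    visits = np.asarray(child_visits, dtype=np.float64)
    total = visits.sum()
    # Q: mean backed-up value; unvisited children default to 0
    q = np.where(visits > 0, values / np.maximum(visits, 1.0), 0.0)
    # U: exploration bonus, large for high-prior, rarely-visited children
    u = C_PUCT * priors * np.sqrt(total + 1.0) / (1.0 + visits)
    return int(np.argmax(q + u))
```

With this rule, a child with many visits and a good mean value competes against a child the policy network likes but the search has not yet explored.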

FILE: src/player/policy_player.py
  class PolicyPlayer (line 17) | class PolicyPlayer(BasePlayer):
    method __init__ (line 18) | def __init__(self, ckpt_path):
    method usi (line 23) | def usi(self):
    method isready (line 27) | def isready(self):
    method go (line 34) | def go(self):
  function argparse (line 72) | def argparse():
  function main (line 81) | def main(args):

FILE: src/player/usi.py
  function usi (line 4) | def usi(player: BasePlayer):
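`src/player/usi.py` implements the read-eval loop for the USI (Universal Shogi Interface) protocol that connects a player object to a GUI; its preview later on this page is truncated after `cmd_line = input()`. A minimal dispatch-loop sketch under the assumption that it handles the standard USI commands exposed by `BasePlayer` (the function name `usi_loop` and the exact dispatch order are illustrative, not taken from the file):

```python
def usi_loop(player):
    # Hedged sketch of a USI dispatch loop; src/player/usi.py may differ.
    # Reads one command per line and forwards it to the player object.
    while True:
        cmd_line = input().strip()
        cmd = cmd_line.split(' ', 1)
        if cmd[0] == 'usi':
            player.usi()            # identify the engine
        elif cmd[0] == 'isready':
            player.isready()        # load model, answer "readyok"
        elif cmd[0] == 'usinewgame':
            player.usinewgame()
        elif cmd[0] == 'setoption':
            player.setoption(cmd[1])
        elif cmd[0] == 'position':
            player.position(cmd[1]) # set up board from sfen/moves
        elif cmd[0] == 'go':
            player.go()             # search and print "bestmove ..."
        elif cmd[0] == 'quit':
            player.quit()
            break
```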

FILE: src/uct/uct_node.py
  function hash_to_index (line 8) | def hash_to_index(zhash):
  class NodeHashEntry (line 12) | class NodeHashEntry:
    method __init__ (line 13) | def __init__(self):
    method reset (line 19) | def reset(self):
  class NodeHash (line 26) | class NodeHash:
    method __init__ (line 27) | def __init__(self):
    method initialize (line 32) | def initialize(self):
    method search_empty_index (line 43) | def search_empty_index(self, zhash, color, moves):
    method find_same_hash_index (line 65) | def find_same_hash_index(self, zhash, color, moves):
    method save_used_hash (line 78) | def save_used_hash(self, board, uct_nodes, n_idx):
    method delete_old_hash (line 93) | def delete_old_hash(self, board, uct_node):
    method get_usage_rate (line 106) | def get_usage_rate(self):
  class UctNode (line 110) | class UctNode:
    method __init__ (line 111) | def __init__(self):
    method reset (line 121) | def reset(self):
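The preview of `src/uct/uct_node.py` further down this page shows `UCT_HASH_SIZE = 4096` and a comment saying `hash_to_index` compresses a Zobrist hash value into `UCT_HASH_SIZE`. Since the function body is truncated in this extraction, the bit-mask below is an assumption; masking is the usual technique when the table size is a power of two:

```python
UCT_HASH_SIZE = 4096  # node-table upper bound (power of two, from the repo)

def hash_to_index(zhash: int) -> int:
    # Compress a 64-bit Zobrist hash into [0, UCT_HASH_SIZE).
    # Assumed implementation: the actual body in src/uct/uct_node.py
    # is cut off in this extraction.
    return zhash & (UCT_HASH_SIZE - 1)
```

Collisions are expected at this table size, which is why `NodeHash` also stores the color and move count and probes with `search_empty_index` / `find_same_hash_index`.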

FILE: src/utils/hcpe.py
  function get_data_from_hcpe (line 8) | def get_data_from_hcpe(hcpes):
  function load_hcpes (line 17) | def load_hcpes(hcpe_paths):

FILE: src/utils/misc.py
  function greedy (line 4) | def greedy(logits):
  function boltzmann (line 8) | def boltzmann(logits, temperature):
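The preview of `src/utils/misc.py` shows that `greedy` is simply an argmax over move logits; `boltzmann` is truncated, so the softmax-with-temperature sampling below is a hedged reconstruction of the standard technique, not the file's verified body:

```python
import numpy as np

def greedy(logits):
    # Pick the highest-scoring move deterministically (matches the preview).
    return np.asarray(logits).argmax()

def boltzmann(logits, temperature):
    # Hedged sketch: sample an index from softmax(logits / temperature).
    # Lower temperature -> closer to greedy; higher -> more exploration.
    # The exact body in src/utils/misc.py is truncated in this extraction.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()          # subtract max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)
```

In an MCTS self-play setting, a temperature near 1 diversifies the opening moves while a small temperature makes play nearly deterministic.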

FILE: src/utils/sfen.py
  function get_gokaku_sfen_paths (line 8) | def get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):
  function get_data_from_sfen (line 19) | def get_data_from_sfen(sfens):
  function load_sfen (line 36) | def load_sfen(sfen_path):
  function load_sfens (line 42) | def load_sfens(sfen_paths):

FILE: tools/make_dataset.py
  function main (line 10) | def main():

FILE: tools/pl_to_transformers.py
  function argparse (line 9) | def argparse():
  function main (line 16) | def main(args):

FILE: tools/test_engine.py
  function parse_args (line 7) | def parse_args():
  function main (line 13) | def main(args):

FILE: tools/train.py
  function argparse (line 12) | def argparse():
  function main (line 22) | def main(args):
Condensed preview — 41 files, each showing path, character count, and a content snippet (full structured content: 61K chars).
[
  {
    "path": ".gitignore",
    "chars": 2012,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "Makefile",
    "chars": 806,
    "preview": "PROJECT ?= bert-mcts\nDATADIR ?= ${PWD}/data\nWORKSPACE ?= /workspace/$(PROJECT)\nDOCKER_IMAGE ?= ${PROJECT}:latest\n\nSHMSIZ"
  },
  {
    "path": "README.md",
    "chars": 2175,
    "preview": "# BERT-MCTS-YOUTUBE\n\nA shogi engine that played against Takumi Yobinori on YouTube.\nIt combines the natural-language model BERT with Monte Carlo tree search (MCTS).  \nBecause it is written entirely in Python, the search"
  },
  {
    "path": "configs/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "configs/mlm_base.yaml",
    "chars": 407,
    "preview": "model_type: 'MLM'\n\nseed: 42\ndataset_dir: './data/dataset/gokaku_100'\nmodel_dir:\n\ntrain_loader:\n  batch_size: 64\n  shuffl"
  },
  {
    "path": "configs/policy_value_base.yaml",
    "chars": 485,
    "preview": "model_type: 'PolicyValue'\n\nseed: 42\ndataset_dir: './data/dataset/selfplay'\nmodel_dir: './work_dirs/mlm_base/version_0/ch"
  },
  {
    "path": "docker/Dockerfile",
    "chars": 226,
    "preview": "FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime\nENV DEBIAN_FRONTEND=noninteractive\nADD requirements.txt /tmp\nRUN pip "
  },
  {
    "path": "docker/entrypoint.sh",
    "chars": 34,
    "preview": "python setup.py develop\nexec \"$@\"\n"
  },
  {
    "path": "engine/mcts_player.sh",
    "chars": 88,
    "preview": "#!/bin/sh\npython -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt\n"
  },
  {
    "path": "engine/policy_player.sh",
    "chars": 90,
    "preview": "#!/bin/sh\npython -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt\n"
  },
  {
    "path": "env_name.yml",
    "chars": 10251,
    "preview": "name: bert-mcts-youtube\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - alabaster=0.7.12"
  },
  {
    "path": "requirements.txt",
    "chars": 134,
    "preview": "--find-links https://download.pytorch.org/whl/torch_stable.html\ntorch>=1.6\npytorch-lightning==1.2.7\ntransformers>=4.5\ncs"
  },
  {
    "path": "setup.py",
    "chars": 390,
    "preview": "from setuptools import setup, find_packages\nfrom torch.utils.cpp_extension import BuildExtension\n\nif __name__ == '__main"
  },
  {
    "path": "src/data/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "src/data/mlm.py",
    "chars": 1269,
    "preview": "import numpy as np\r\nimport torch\r\nfrom torch.utils.data import Dataset\r\n\r\nfrom src.utils.shogi import pieces_list\r\n\r\n\r\n#"
  },
  {
    "path": "src/data/policy_value.py",
    "chars": 604,
    "preview": "import torch\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass PolicyValueDataset(Dataset):\r\n    def __init__(self, data)"
  },
  {
    "path": "src/features/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "src/features/common.py",
    "chars": 334,
    "preview": "import cshogi\n\nfrom src.utils.shogi import reverse_piece_fn\n\n\ndef get_seq_from_board(board):\n    bp, wp = board.pieces_i"
  },
  {
    "path": "src/features/policy_value.py",
    "chars": 3308,
    "preview": "import cshogi\r\nimport numpy as np\r\nfrom cshogi import move_drop_hand_piece, move_from, move_is_drop, move_is_promotion, "
  },
  {
    "path": "src/model/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "src/model/bert.py",
    "chars": 2378,
    "preview": "import torch.nn as nn\r\nfrom transformers import BertConfig, BertForMaskedLM, BertModel\r\n\r\nfrom src.utils.shogi import pi"
  },
  {
    "path": "src/pl_modules/__init__.py",
    "chars": 370,
    "preview": "from .mlm import MLMModule, MLMDataModule\nfrom .policy_value import PolicyValueModule, PolicyValueDataModule\n\n\ndef get_p"
  },
  {
    "path": "src/pl_modules/mlm.py",
    "chars": 1817,
    "preview": "from pathlib import Path\n\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoader\nfrom"
  },
  {
    "path": "src/pl_modules/policy_value.py",
    "chars": 2203,
    "preview": "from pathlib import Path\n\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoader\nfrom"
  },
  {
    "path": "src/player/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "src/player/base_player.py",
    "chars": 678,
    "preview": "import cshogi\n\n\nclass BasePlayer:\n    def __init__(self):\n        self.board = cshogi.Board()\n\n    def usi(self):\n      "
  },
  {
    "path": "src/player/mcts_player.py",
    "chars": 8873,
    "preview": "import time\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nimport cshogi\nimport numpy as np\nimport torch\n"
  },
  {
    "path": "src/player/policy_player.py",
    "chars": 2485,
    "preview": "from argparse import ArgumentParser\nfrom pathlib import Path\nimport numpy as np\n\nimport cshogi\nimport torch\nimport torch"
  },
  {
    "path": "src/player/usi.py",
    "chars": 731,
    "preview": "from src.player.base_player import BasePlayer\n\n\ndef usi(player: BasePlayer):\n    while True:\n        cmd_line = input()\n"
  },
  {
    "path": "src/uct/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "src/uct/uct_node.py",
    "chars": 3881,
    "preview": "# Upper bound on the number of nodes\nUCT_HASH_SIZE = 4096\n# Index marking an unexpanded node\nNOT_EXPANDED = -1\n\n\n# Compress a Zobrist hash value into UCT_HASH_SIZE\ndef hash_to_index(zhas"
  },
  {
    "path": "src/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/utils/hcpe.py",
    "chars": 583,
    "preview": "import cshogi\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom src.features.policy_value import get_policy_value_label\n\n\nd"
  },
  {
    "path": "src/utils/misc.py",
    "chars": 273,
    "preview": "import numpy as np\n\n\ndef greedy(logits):\n    return np.asarray(logits).argmax()\n\n\ndef boltzmann(logits, temperature):\n  "
  },
  {
    "path": "src/utils/sfen.py",
    "chars": 1346,
    "preview": "from pathlib import Path\n\nimport numpy as np\n\nfrom src.features.policy_value import get_moves_from_lines, get_policy_val"
  },
  {
    "path": "src/utils/shogi.py",
    "chars": 933,
    "preview": "import numpy as np\r\n\r\n# Movement-direction constants\r\nMOVE_DIRECTION = [\r\n    UP, UP_LEFT, UP_RIGHT, LEFT, RIGHT, DOWN, DOWN_LEFT, DOWN_RIGHT, "
  },
  {
    "path": "tools/download_and_build_lesserkai.sh",
    "chars": 202,
    "preview": "wget http://shogidokoro.starfree.jp/download/LesserkaiSrc.zip -P ./work_dirs\nunzip ./work_dirs/LesserkaiSrc.zip -d ./wor"
  },
  {
    "path": "tools/make_dataset.py",
    "chars": 1354,
    "preview": "from pathlib import Path\n\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nfrom src.utils.hcpe i"
  },
  {
    "path": "tools/pl_to_transformers.py",
    "chars": 819,
    "preview": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nimport torch\n\nfrom src.model.bert import config\n\n\ndef argp"
  },
  {
    "path": "tools/test_engine.py",
    "chars": 528,
    "preview": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nfrom cshogi import cli\n\n\ndef parse_args():\n    parser = Ar"
  },
  {
    "path": "tools/train.py",
    "chars": 1453,
    "preview": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nimport pytorch_lightning as pl\nfrom omegaconf import Omega"
  }
]

About this extraction

This page contains the full source code of the nyoki-mtl/bert-mcts-youtube GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 41 files (52.3 KB), approximately 19.2k tokens, and a symbol index with 104 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.