Repository: nyoki-mtl/bert-mcts-youtube Branch: main Commit: a12e0bdaf313 Files: 41 Total size: 52.3 KB Directory structure: gitextract_4mn9ams8/ ├── .gitignore ├── Makefile ├── README.md ├── configs/ │ ├── .gitkeep │ ├── mlm_base.yaml │ └── policy_value_base.yaml ├── docker/ │ ├── Dockerfile │ └── entrypoint.sh ├── engine/ │ ├── mcts_player.sh │ └── policy_player.sh ├── env_name.yml ├── requirements.txt ├── setup.py ├── src/ │ ├── data/ │ │ ├── __init__.py │ │ ├── mlm.py │ │ └── policy_value.py │ ├── features/ │ │ ├── __init__.py │ │ ├── common.py │ │ └── policy_value.py │ ├── model/ │ │ ├── __init__.py │ │ └── bert.py │ ├── pl_modules/ │ │ ├── __init__.py │ │ ├── mlm.py │ │ └── policy_value.py │ ├── player/ │ │ ├── __init__.py │ │ ├── base_player.py │ │ ├── mcts_player.py │ │ ├── policy_player.py │ │ └── usi.py │ ├── uct/ │ │ ├── __init__.py │ │ └── uct_node.py │ └── utils/ │ ├── __init__.py │ ├── hcpe.py │ ├── misc.py │ ├── sfen.py │ └── shogi.py └── tools/ ├── download_and_build_lesserkai.sh ├── make_dataset.py ├── pl_to_transformers.py ├── test_engine.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # Project specific folders lightning_logs/ tf_logs/ remote_logs/ local_logs/ *_logs/ checkpoints/ dataset/ local_misc/ notebooks/data/ # LSF logfiles lsf.* # IDEA folders .idea/ /work_dirs/* /data/* !.gitkeep ================================================ FILE: Makefile ================================================ PROJECT ?= bert-mcts DATADIR ?= ${PWD}/data WORKSPACE ?= /workspace/$(PROJECT) DOCKER_IMAGE ?= ${PROJECT}:latest SHMSIZE ?= 100G DOCKER_OPTS := \ --name ${PROJECT} \ --rm -it \ --shm-size=${SHMSIZE} \ -v ${PWD}:${WORKSPACE} \ -v ${DATADIR}:${WORKSPACE}/data \ -v ${LOG_DIR}:${WORKSPACE}/work_dirs/logs \ -w ${WORKSPACE} \ --ipc=host \ --network=host \ --gpus all docker-build: 
docker build -f docker/Dockerfile -t ${DOCKER_IMAGE} . docker-start-interactive: docker-build docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash docker-start-jupyter: docker-build docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ bash -c "jupyter lab --port=8888 --ip=0.0.0.0 --allow-root --no-browser" docker-run: docker-build docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ bash -c "${COMMAND}" ================================================ FILE: README.md ================================================ # BERT-MCTS-YOUTUBE YouTubeにてヨビノリたくみさんと対戦した将棋ソフトです。 自然言語モデルであるBERTとモンテカルロ木探索(MCTS)の組み合わせで出来ています。 すべてpythonで書いてあるため、探索の速度は遅いです。 BERT以外の大部分は『将棋AIで学ぶディープラーニング』を参考に書いています。 - [書籍(amazon)](https://www.amazon.co.jp/dp/B07B7JJ929) - [github](https://github.com/TadaoYamaoka/python-dlshogi) ## 環境 ### Colab テストするだけなら[google colab](https://colab.research.google.com/drive/10KAuLlNe6FKZBp_iE2bQJPNhoY2WeACx?usp=sharing) が簡単です。 以下はローカルで試す場合。CPUだと遅いのでCUDA環境が望ましいです。 重みファイルは[ここ](https://drive.google.com/drive/folders/1N-Np2NmNLtLGS9gjnreBkYdTxrH1EHFw?usp=sharing) にアップしてあり、 たくみさんと戦った重みファイルがyoutube_version.ckpt、追加で数日間学習させた重みファイルがlatest.ckptになります。 ダウンロード先のパスはengine/***_player.sh内で指定してください。 デフォルトではwork_dirs以下にダウンロードすることを想定しています。 ### Docker cuda10.2以上のnvidia-dockerが整っているなら次のコマンドで環境に入れます。 ```bash $ make docker-start-interactive ``` ### Ubuntu18.04 cuda10.2でanacondaが入っていれば次のコマンドで仮想環境を作れます。 ```bash $ conda env create -f env_name.yml $ conda activate bert-mcts-youtube $ python setup.py develop ``` ### Windows10 未検証 ## 将棋エンジンのテスト エンジンはengineディレクトリに用意しています。これらはShogiGUIなどから呼び出すことができます。 - policy_player.shはBERTの出力する方策のみを頼りに指すモデル(弱い) - mcts_player.shはBERTの出力をもとにMCTSで探索するモデル ## 学習 学習には互角局面集とGCTの自己対戦棋譜を用いました。 モデルはMasked Language Modelで事前学習してから、Policy Value Networkの学習という手順を踏みます。 ただし、将棋は良質な教師データが大量にあるため事前学習の効果はあまりない気がします。 ### データの準備 互角局面集のダウンロード ```bash $ cd data $ git clone https://github.com/tttak/ShogiGokakuKyokumen.git ``` GCTの自己対戦棋譜 ```bash $ cd data $ mkdir hcpe ``` 
GCTの自己対戦棋譜は開発者の加納さんが[リンク](https://drive.google.com/drive/folders/14FaqqIHRctTQIY6hScCFXWQQZ_pSU3-F) に公開してくださっています。 ここからselfplay-***となっているファイルをいくつかdata/hcpe以下にダウンロードしてください。 サイズが大きいので一個でも十分な量あります。 これらを準備できたら以下のコマンドでデータセットを作ります。 ```bash $ python tools/make_dataset.py ``` ### Masked Language Modelの学習 ```bash $ python tools/train.py configs/mlm_base.yaml ``` ### 重みファイルの変換 Masked Language Modelのチェックポイントをtransformers形式に変換しておきます。 これによって転移学習のコードが多少書きやすくなります。 ```bash $ python tools/pl_to_transformers.py work_dirs/mlm_base/version_0/checkpoints/last.ckpt ``` ### Policy Value Modelの学習 最後にこれらを使ってPolicy Valueを学習させます。 ```bash $ python tools/train.py configs/policy_value_base.yaml ``` ================================================ FILE: configs/.gitkeep ================================================ ================================================ FILE: configs/mlm_base.yaml ================================================ model_type: 'MLM' seed: 42 dataset_dir: './data/dataset/gokaku_100' model_dir: train_loader: batch_size: 64 shuffle: True num_workers: 8 pin_memory: False drop_last: True val_loader: batch_size: 64 shuffle: False num_workers: 8 pin_memory: False drop_last: False train_params: max_epochs: 5 # validationおよびcheckpointの間隔step数 val_check_interval: 3000 # 環境に応じて,適宜変更 gpus: [0] ================================================ FILE: configs/policy_value_base.yaml ================================================ model_type: 'PolicyValue' seed: 42 dataset_dir: './data/dataset/selfplay' model_dir: './work_dirs/mlm_base/version_0/checkpoints' train_loader: batch_size: 128 shuffle: True num_workers: 0 pin_memory: True drop_last: True val_loader: batch_size: 128 shuffle: False num_workers: 0 pin_memory: True drop_last: False train_params: max_epochs: 1 # validationおよびcheckpointの間隔step数 val_check_interval: 30000 limit_val_batches: 0.1 # 環境に応じて,適宜変更 gpus: [0] ================================================ FILE: docker/Dockerfile ================================================
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime ENV DEBIAN_FRONTEND=noninteractive ADD requirements.txt /tmp RUN pip install -r /tmp/requirements.txt ADD docker/entrypoint.sh /tmp ENTRYPOINT ["bash", "/tmp/entrypoint.sh"] ================================================ FILE: docker/entrypoint.sh ================================================ python setup.py develop exec "$@" ================================================ FILE: engine/mcts_player.sh ================================================ #!/bin/sh python -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt ================================================ FILE: engine/policy_player.sh ================================================ #!/bin/sh python -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt ================================================ FILE: env_name.yml ================================================ name: bert-mcts-youtube channels: - pytorch - defaults dependencies: - _libgcc_mutex=0.1=main - alabaster=0.7.12=py37_0 - anaconda=2020.11=py37_0 - anaconda-client=1.7.2=py37_0 - anaconda-project=0.8.4=py_0 - argh=0.26.2=py37_0 - argon2-cffi=20.1.0=py37h7b6447c_1 - asn1crypto=1.4.0=py_0 - astroid=2.4.2=py37_0 - astropy=4.0.2=py37h7b6447c_0 - async_generator=1.10=py37h28b3542_0 - atomicwrites=1.4.0=py_0 - attrs=20.3.0=pyhd3eb1b0_0 - autopep8=1.5.4=py_0 - babel=2.8.1=pyhd3eb1b0_0 - backcall=0.2.0=py_0 - backports=1.0=py_2 - backports.shutil_get_terminal_size=1.0.0=py37_2 - beautifulsoup4=4.9.3=pyhb0f4dca_0 - bitarray=1.6.1=py37h27cfd23_0 - bkcharts=0.2=py37_0 - blas=1.0=mkl - bleach=3.2.1=py_0 - blosc=1.20.1=hd408876_0 - bokeh=2.2.3=py37_0 - boto=2.49.0=py37_0 - bottleneck=1.3.2=py37heb32a55_1 - brotlipy=0.7.0=py37h7b6447c_1000 - bzip2=1.0.8=h7b6447c_0 - ca-certificates=2020.10.14=0 - cairo=1.14.12=h8948797_3 - certifi=2020.6.20=pyhd3eb1b0_3 - cffi=1.14.3=py37he30daa8_0 - chardet=3.0.4=py37_1003 - click=7.1.2=py_0 - cloudpickle=1.6.0=py_0 - 
clyent=1.2.2=py37_1 - colorama=0.4.4=py_0 - contextlib2=0.6.0.post1=py_0 - cryptography=3.1.1=py37h1ba5d50_0 - cudatoolkit=10.2.89=hfd86e86_1 - curl=7.71.1=hbc83047_1 - cycler=0.10.0=py37_0 - cython=0.29.21=py37he6710b0_0 - cytoolz=0.11.0=py37h7b6447c_0 - dask=2.30.0=py_0 - dask-core=2.30.0=py_0 - dbus=1.13.18=hb2f20db_0 - decorator=4.4.2=py_0 - defusedxml=0.6.0=py_0 - diff-match-patch=20200713=py_0 - distributed=2.30.1=py37h06a4308_0 - docutils=0.16=py37_1 - entrypoints=0.3=py37_0 - et_xmlfile=1.0.1=py_1001 - expat=2.2.10=he6710b0_2 - fastcache=1.1.0=py37h7b6447c_0 - ffmpeg=4.3=hf484d3e_0 - filelock=3.0.12=py_0 - flake8=3.8.4=py_0 - flask=1.1.2=py_0 - fontconfig=2.13.0=h9420a91_0 - freetype=2.10.4=h5ab3b9f_0 - fribidi=1.0.10=h7b6447c_0 - fsspec=0.8.3=py_0 - future=0.18.2=py37_1 - get_terminal_size=1.0.0=haa9412d_0 - gevent=20.9.0=py37h7b6447c_0 - glib=2.66.1=h92f7085_0 - glob2=0.7=py_0 - gmp=6.1.2=h6c8ec71_1 - gmpy2=2.0.8=py37h10f8cd9_2 - gnutls=3.6.15=he1e5248_0 - graphite2=1.3.14=h23475e2_0 - greenlet=0.4.17=py37h7b6447c_0 - gst-plugins-base=1.14.0=hbbd80ab_1 - gstreamer=1.14.0=hb31296c_0 - h5py=2.10.0=py37h7918eee_0 - harfbuzz=2.4.0=hca77d97_1 - hdf5=1.10.4=hb1b8bf9_0 - heapdict=1.0.1=py_0 - html5lib=1.1=py_0 - icu=58.2=he6710b0_3 - idna=2.10=py_0 - imageio=2.9.0=py_0 - imagesize=1.2.0=py_0 - importlib-metadata=2.0.0=py_1 - importlib_metadata=2.0.0=1 - iniconfig=1.1.1=py_0 - intel-openmp=2020.2=254 - intervaltree=3.1.0=py_0 - ipykernel=5.3.4=py37h5ca1d4c_0 - ipython=7.19.0=py37hb070fc8_0 - ipython_genutils=0.2.0=py37_0 - ipywidgets=7.5.1=py_1 - isort=5.6.4=py_0 - itsdangerous=1.1.0=py37_0 - jbig=2.1=hdba287a_0 - jdcal=1.4.1=py_0 - jedi=0.17.1=py37_0 - jeepney=0.5.0=pyhd3eb1b0_0 - jinja2=2.11.2=py_0 - joblib=0.17.0=py_0 - jpeg=9b=h024ee3a_2 - json5=0.9.5=py_0 - jsonschema=3.2.0=py_2 - jupyter=1.0.0=py37_7 - jupyter_client=6.1.7=py_0 - jupyter_console=6.2.0=py_0 - jupyter_core=4.6.3=py37_0 - jupyterlab=2.2.6=py_0 - jupyterlab_pygments=0.1.2=py_0 - 
jupyterlab_server=1.2.0=py_0 - keyring=21.4.0=py37_1 - kiwisolver=1.3.0=py37h2531618_0 - krb5=1.18.2=h173b8e3_0 - lame=3.100=h7b6447c_0 - lazy-object-proxy=1.4.3=py37h7b6447c_0 - lcms2=2.11=h396b838_0 - ld_impl_linux-64=2.33.1=h53a641e_7 - libarchive=3.4.2=h62408e4_0 - libcurl=7.71.1=h20c2e04_1 - libedit=3.1.20191231=h14c3975_1 - libffi=3.3=he6710b0_2 - libgcc-ng=9.1.0=hdf63c60_0 - libgfortran-ng=7.3.0=hdf63c60_0 - libiconv=1.15=h63c8f33_5 - libidn2=2.3.0=h27cfd23_0 - liblief=0.10.1=he6710b0_0 - libllvm10=10.0.1=hbcb73fb_5 - libpng=1.6.37=hbc83047_0 - libsodium=1.0.18=h7b6447c_0 - libspatialindex=1.9.3=he6710b0_0 - libssh2=1.9.0=h1ba5d50_1 - libstdcxx-ng=9.1.0=hdf63c60_0 - libtasn1=4.16.0=h27cfd23_0 - libtiff=4.1.0=h2733197_1 - libtool=2.4.6=h7b6447c_1005 - libunistring=0.9.10=h27cfd23_0 - libuuid=1.0.3=h1bed415_2 - libuv=1.40.0=h7b6447c_0 - libxcb=1.14=h7b6447c_0 - libxml2=2.9.10=hb55368b_3 - libxslt=1.1.34=hc22bd24_0 - llvmlite=0.34.0=py37h269e1b5_4 - locket=0.2.0=py37_1 - lxml=4.6.1=py37hefd8a0e_0 - lz4-c=1.9.2=heb0550a_3 - lzo=2.10=h7b6447c_2 - markupsafe=1.1.1=py37h14c3975_1 - matplotlib=3.3.2=0 - matplotlib-base=3.3.2=py37h817c723_0 - mccabe=0.6.1=py37_1 - mistune=0.8.4=py37h14c3975_1001 - mkl=2020.2=256 - mkl-service=2.3.0=py37he904b0f_0 - mkl_fft=1.2.0=py37h23d657b_0 - mkl_random=1.1.1=py37h0573a6f_0 - mock=4.0.2=py_0 - more-itertools=8.6.0=pyhd3eb1b0_0 - mpc=1.1.0=h10f8cd9_1 - mpfr=4.0.2=hb69a4c5_1 - mpmath=1.1.0=py37_0 - msgpack-python=1.0.0=py37hfd86e86_1 - multipledispatch=0.6.0=py37_0 - nbclient=0.5.1=py_0 - nbconvert=6.0.7=py37_0 - nbformat=5.0.8=py_0 - ncurses=6.2=he6710b0_1 - nest-asyncio=1.4.2=pyhd3eb1b0_0 - nettle=3.7.2=hbbd107a_1 - networkx=2.5=py_0 - ninja=1.10.2=py37hff7bd54_0 - nltk=3.5=py_0 - nose=1.3.7=py37_1004 - notebook=6.1.4=py37_0 - numba=0.51.2=py37h04863e7_1 - numexpr=2.7.1=py37h423224d_0 - numpy=1.19.2=py37h54aff64_0 - numpy-base=1.19.2=py37hfa32c7d_0 - numpydoc=1.1.0=pyhd3eb1b0_1 - olefile=0.46=py37_0 - openh264=2.1.0=hd408876_0 - 
openpyxl=3.0.5=py_0 - openssl=1.1.1h=h7b6447c_0 - packaging=20.4=py_0 - pandas=1.1.3=py37he6710b0_0 - pandoc=2.11=hb0f4dca_0 - pandocfilters=1.4.3=py37h06a4308_1 - pango=1.45.3=hd140c19_0 - parso=0.7.0=py_0 - partd=1.1.0=py_0 - patchelf=0.12=he6710b0_0 - path=15.0.0=py37_0 - path.py=12.5.0=0 - pathlib2=2.3.5=py37_1 - pathtools=0.1.2=py_1 - patsy=0.5.1=py37_0 - pcre=8.44=he6710b0_0 - pep8=1.7.1=py37_0 - pexpect=4.8.0=py37_1 - pickleshare=0.7.5=py37_1001 - pillow=8.0.1=py37he98fc37_0 - pip=20.2.4=py37h06a4308_0 - pixman=0.40.0=h7b6447c_0 - pkginfo=1.6.1=py37h06a4308_0 - pluggy=0.13.1=py37_0 - ply=3.11=py37_0 - prometheus_client=0.8.0=py_0 - prompt-toolkit=3.0.8=py_0 - prompt_toolkit=3.0.8=0 - psutil=5.7.2=py37h7b6447c_0 - ptyprocess=0.6.0=py37_0 - py=1.9.0=py_0 - py-lief=0.10.1=py37h403a769_0 - pycodestyle=2.6.0=py_0 - pycosat=0.6.3=py37h7b6447c_0 - pycparser=2.20=py_2 - pycrypto=2.6.1=py37h7b6447c_10 - pycurl=7.43.0.6=py37h1ba5d50_0 - pydocstyle=5.1.1=py_0 - pyflakes=2.2.0=py_0 - pygments=2.7.2=pyhd3eb1b0_0 - pylint=2.6.0=py37_0 - pyodbc=4.0.30=py37he6710b0_0 - pyopenssl=19.1.0=py_1 - pyparsing=2.4.7=py_0 - pyqt=5.9.2=py37h05f1152_2 - pyrsistent=0.17.3=py37h7b6447c_0 - pysocks=1.7.1=py37_1 - pytables=3.6.1=py37h71ec239_0 - pytest=6.1.1=py37_0 - python=3.7.9=h7579374_0 - python-dateutil=2.8.1=py_0 - python-jsonrpc-server=0.4.0=py_0 - python-language-server=0.35.1=py_0 - python-libarchive-c=2.9=py_0 - pytorch=1.8.1=py3.7_cuda10.2_cudnn7.6.5_0 - pytz=2020.1=py_0 - pywavelets=1.1.1=py37h7b6447c_2 - pyxdg=0.27=pyhd3eb1b0_0 - pyyaml=5.3.1=py37h7b6447c_1 - pyzmq=19.0.2=py37he6710b0_1 - qdarkstyle=2.8.1=py_0 - qt=5.9.7=h5867ecd_1 - qtawesome=1.0.1=py_0 - qtconsole=4.7.7=py_0 - qtpy=1.9.0=py_0 - readline=8.0=h7b6447c_0 - regex=2020.10.15=py37h7b6447c_0 - requests=2.24.0=py_0 - ripgrep=12.1.1=0 - rope=0.18.0=py_0 - rtree=0.9.4=py37_1 - ruamel_yaml=0.15.87=py37h7b6447c_1 - scikit-image=0.17.2=py37hdf5156a_0 - scikit-learn=0.23.2=py37h0573a6f_0 - scipy=1.5.2=py37h0b6359f_0 - 
seaborn=0.11.0=py_0 - secretstorage=3.1.2=py37_1 - send2trash=1.5.0=py37_0 - setuptools=50.3.1=py37h06a4308_1 - simplegeneric=0.8.1=py37_2 - singledispatch=3.4.0.3=py_1001 - sip=4.19.8=py37hf484d3e_0 - six=1.15.0=py37h06a4308_0 - snowballstemmer=2.0.0=py_0 - sortedcollections=1.2.1=py_0 - sortedcontainers=2.2.2=py_0 - soupsieve=2.0.1=py_0 - sphinx=3.2.1=py_0 - sphinxcontrib=1.0=py37_1 - sphinxcontrib-applehelp=1.0.2=py_0 - sphinxcontrib-devhelp=1.0.2=py_0 - sphinxcontrib-htmlhelp=1.0.3=py_0 - sphinxcontrib-jsmath=1.0.1=py_0 - sphinxcontrib-qthelp=1.0.3=py_0 - sphinxcontrib-serializinghtml=1.1.4=py_0 - sphinxcontrib-websupport=1.2.4=py_0 - spyder=4.1.5=py37_0 - spyder-kernels=1.9.4=py37_0 - sqlalchemy=1.3.20=py37h27cfd23_0 - sqlite=3.33.0=h62c20be_0 - statsmodels=0.12.0=py37h7b6447c_0 - sympy=1.6.2=py37h06a4308_1 - tbb=2020.3=hfd86e86_0 - tblib=1.7.0=py_0 - terminado=0.9.1=py37_0 - testpath=0.4.4=py_0 - threadpoolctl=2.1.0=pyh5ca1d4c_0 - tifffile=2020.10.1=py37hdd07704_2 - tk=8.6.10=hbc83047_0 - toml=0.10.1=py_0 - toolz=0.11.1=py_0 - torchvision=0.9.1=py37_cu102 - tornado=6.0.4=py37h7b6447c_1 - tqdm=4.50.2=py_0 - traitlets=5.0.5=py_0 - typed-ast=1.4.1=py37h7b6447c_0 - typing_extensions=3.7.4.3=py_0 - ujson=4.0.1=py37he6710b0_0 - unicodecsv=0.14.1=py37_0 - unixodbc=2.3.9=h7b6447c_0 - urllib3=1.25.11=py_0 - watchdog=0.10.3=py37_0 - wcwidth=0.2.5=py_0 - webencodings=0.5.1=py37_1 - werkzeug=1.0.1=py_0 - wheel=0.35.1=py_0 - widgetsnbextension=3.5.1=py37_0 - wrapt=1.11.2=py37h7b6447c_0 - wurlitzer=2.0.1=py37_0 - xlrd=1.2.0=py37_0 - xlsxwriter=1.3.7=py_0 - xlwt=1.3.0=py37_0 - xz=5.2.5=h7b6447c_0 - yaml=0.2.5=h7b6447c_0 - yapf=0.30.0=py_0 - zeromq=4.3.3=he6710b0_3 - zict=2.0.0=py_0 - zipp=3.4.0=pyhd3eb1b0_0 - zlib=1.2.11=h7b6447c_3 - zope=1.0=py37_1 - zope.event=4.5.0=py37_0 - zope.interface=5.1.2=py37h7b6447c_0 - zstd=1.4.5=h9ceee32_0 - pip: - absl-py==0.12.0 - aiohttp==3.7.4.post0 - async-timeout==3.0.1 - cachetools==4.2.1 - cshogi==0.2.4 - google-auth==1.28.1 - 
google-auth-oauthlib==0.4.4 - grpcio==1.37.0 - markdown==3.3.4 - multidict==5.1.0 - oauthlib==3.1.0 - omegaconf==2.0.6 - protobuf==3.15.8 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pytorch-lightning==1.2.7 - requests-oauthlib==1.3.0 - rsa==4.7.2 - sacremoses==0.0.44 - tensorboard==2.4.1 - tensorboard-plugin-wit==1.8.0 - tokenizers==0.10.2 - torchmetrics==0.2.0 - transformers==4.5.1 - yarl==1.6.3 prefix: /home/charmer/.pyenv/versions/anaconda3-2020.07/envs/bert-mcts ================================================ FILE: requirements.txt ================================================ --find-links https://download.pytorch.org/whl/torch_stable.html torch>=1.6 pytorch-lightning==1.2.7 transformers>=4.5 cshogi omegaconf ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages from torch.utils.cpp_extension import BuildExtension if __name__ == '__main__': setup( name='bert-mcts', description='BERT-driven Shogi AI', author='Hiroki Taniai', author_email='charmer.popopo@gmail.com', packages=find_packages(include=['src']), cmdclass={'build_ext': BuildExtension}, ) ================================================ FILE: src/data/__init__.py ================================================ ================================================ FILE: src/data/mlm.py ================================================ import numpy as np import torch from torch.utils.data import Dataset from src.utils.shogi import pieces_list # Masked Language Model class MLMDataset(Dataset): def __init__(self, data): self.seqs = data['seq'] self.mask_token_id = 32 # 32は駒が割り振られていないid assert self.mask_token_id not in np.array(pieces_list) def __len__(self): return len(self.seqs) def __getitem__(self, idx): inputs = np.array(self.seqs[idx]) labels = inputs.copy() # 予想対象 masked_indices = np.random.random(labels.shape) < 0.15 labels[~masked_indices] = -100 # 80%はマスクトークンに indices_replaced = 
(np.random.random(labels.shape) < 0.8) & masked_indices inputs[indices_replaced] = self.mask_token_id # 10%はランダムに置き換え indices_random = (np.random.random(labels.shape) < 0.5) & masked_indices & ~indices_replaced random_words = np.random.choice(pieces_list, labels.shape) inputs[indices_random] = random_words[indices_random] # 残り10%はそのままのものが残る ret_dict = {'input_ids': torch.tensor(inputs, dtype=torch.long), 'labels': torch.tensor(labels, dtype=torch.long)} return ret_dict ================================================ FILE: src/data/policy_value.py ================================================ import torch from torch.utils.data import Dataset class PolicyValueDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): dt = self.data[idx] ret_dict = {'input_ids': torch.tensor(dt['seq'], dtype=torch.long), 'labels': torch.tensor(dt['label'], dtype=torch.long), 'values': torch.tensor(dt['value'], dtype=torch.float), 'result': torch.tensor(dt['result'], dtype=torch.long)} return ret_dict ================================================ FILE: src/features/__init__.py ================================================ ================================================ FILE: src/features/common.py ================================================ import cshogi from src.utils.shogi import reverse_piece_fn def get_seq_from_board(board): bp, wp = board.pieces_in_hand if board.turn == cshogi.BLACK: seq = board.pieces + bp + wp else: # 駒の順番を逆にして、駒を先後反転させる。持ち駒は逆にする。 seq = reverse_piece_fn(board.pieces[::-1]).tolist() + wp + bp return seq ================================================ FILE: src/features/policy_value.py ================================================ import cshogi import numpy as np from cshogi import move_drop_hand_piece, move_from, move_is_drop, move_is_promotion, move_to from src.features.common import get_seq_from_board from src.utils.shogi import (DOWN, DOWN_LEFT, DOWN_RIGHT, LEFT, 
MOVE_DIRECTION, MOVE_DIRECTION_PROMOTED, RIGHT, UP, UP2_LEFT, UP2_RIGHT, UP_LEFT, UP_RIGHT) board = cshogi.Board() def get_move_label(move, color): if not move_is_drop(move): from_sq = move_from(move) to_sq = move_to(move) if color == cshogi.WHITE: to_sq = 80 - to_sq from_sq = 80 - from_sq # file: 筋, rank: 段 from_file, from_rank = divmod(from_sq, 9) to_file, to_rank = divmod(to_sq, 9) dir_file = to_file - from_file dir_rank = to_rank - from_rank if dir_rank < 0 and dir_file == 0: move_direction = UP elif dir_rank == -2 and dir_file == -1: move_direction = UP2_RIGHT elif dir_rank == -2 and dir_file == 1: move_direction = UP2_LEFT elif dir_rank < 0 and dir_file < 0: move_direction = UP_RIGHT elif dir_rank < 0 and dir_file > 0: move_direction = UP_LEFT elif dir_rank == 0 and dir_file < 0: move_direction = RIGHT elif dir_rank == 0 and dir_file > 0: move_direction = LEFT elif dir_rank > 0 and dir_file == 0: move_direction = DOWN elif dir_rank > 0 and dir_file < 0: move_direction = DOWN_RIGHT elif dir_rank > 0 and dir_file > 0: move_direction = DOWN_LEFT else: raise RuntimeError # promote if move_is_promotion(move): move_direction = MOVE_DIRECTION_PROMOTED[move_direction] else: # 持ち駒 move_direction = len(MOVE_DIRECTION) + move_drop_hand_piece(move) - 1 to_sq = move_to(move) if color == cshogi.WHITE: to_sq = 80 - to_sq # labelのmaxは27*81-1=2186 move_label = move_direction * 81 + to_sq return move_label def get_result(result, color): # 引き分け if result == 0: return 0.5 # 手番が勝ち elif ((result == 1) and (color == cshogi.BLACK)) or ((result == 2) and (color == cshogi.WHITE)): return 1 else: return 0 def get_policy_value_label(hcpe): board.reset() board.set_hcp(hcpe['hcp']) seq = get_seq_from_board(board) label = get_move_label(hcpe['bestMove16'], board.turn) value = 1 / (1 + np.exp(-hcpe['eval'] * 0.0013226)) result = get_result(hcpe['gameResult'], board.turn) return seq, label, value, result def get_policy_value_label_from_moves(moves): n = np.random.randint(4, len(moves)-1) 
board.reset() for move in moves[:n]: board.push(move) seq = get_seq_from_board(board) label = get_move_label(moves[n], board.turn) value = 0.5 result = 0.5 return seq, label, value, result def get_moves_from_lines(line): board.reset() moves = [] for move_usi in line.split()[2:]: move = board.move_from_usi(move_usi) board.push(move) moves.append(move) return moves ================================================ FILE: src/model/__init__.py ================================================ ================================================ FILE: src/model/bert.py ================================================ import torch.nn as nn from transformers import BertConfig, BertForMaskedLM, BertModel from src.utils.shogi import pieces_list config = { 'vocab_size': len(pieces_list) + 4, # MASK_TOKEN_ID, MASK, CLS, SEP 'hidden_size': 768, 'num_hidden_layers': 12, 'num_attention_heads': 12, 'intermediate_size': 3072, # hidden_size * 4が目安 'hidden_act': 'gelu', 'hidden_dropout_prob': 0.1, 'attention_probs_dropout_prob': 0.1, 'max_position_embeddings': 512, # 95(=81(マス目)+7(先手持駒)+7(後手持駒))でいいかも 'type_vocab_size': 1, # 対の文章を入れない。つまりtoken_type_embeddingsは完全に無駄になっている。 'initializer_range': 0.02, } config = BertConfig.from_dict(config) class BertMLM(nn.Module): def __init__(self, model_dir=None): super().__init__() if model_dir is None: self.bert = BertForMaskedLM(config) else: self.bert = BertForMaskedLM.from_pretrained(model_dir) def forward(self, input_ids, labels): return self.bert(input_ids=input_ids, labels=labels) class BertPolicyValue(nn.Module): def __init__(self, model_dir=None): super().__init__() if model_dir is None: self.bert = BertModel(config) else: self.bert = BertModel.from_pretrained(model_dir) self.policy_head = nn.Sequential( nn.Linear(768, 768 * 2), nn.Tanh(), nn.Linear(768 * 2, 9 * 9 * 27) ) self.value_head = nn.Sequential( nn.Linear(768, 768 * 2), nn.Tanh(), nn.Linear(768 * 2, 1), nn.Sigmoid() ) self.loss_policy_fn = nn.CrossEntropyLoss() self.loss_value_fn = 
nn.MSELoss() def forward(self, input_ids, labels=None): features = self.bert(input_ids=input_ids)['last_hidden_state'] policy = self.policy_head(features).mean(axis=1) value = self.value_head(features).mean(axis=1).squeeze(1) if labels is None: return {'policy': policy, 'value': value} else: loss_policy = self.loss_policy_fn(policy, labels['labels']) loss_value = self.loss_value_fn(value, labels['values']) loss = loss_policy + loss_value return {'loss_policy': loss_policy, 'loss_value': loss_value, 'loss': loss} ================================================ FILE: src/pl_modules/__init__.py ================================================ from .mlm import MLMModule, MLMDataModule from .policy_value import PolicyValueModule, PolicyValueDataModule def get_pl_modules(cfg): if cfg.model_type == 'MLM': return MLMModule(cfg), MLMDataModule(cfg) elif cfg.model_type == 'PolicyValue': return PolicyValueModule(cfg), PolicyValueDataModule(cfg) else: raise NotImplementedError ================================================ FILE: src/pl_modules/mlm.py ================================================ from pathlib import Path import numpy as np import pytorch_lightning as pl from torch.utils.data import DataLoader from transformers import AdamW from src.data.mlm import MLMDataset from src.model.bert import BertMLM class MLMModule(pl.LightningModule): def __init__(self, hparams): super().__init__() self.hparams = hparams self.model = BertMLM(hparams['model_dir']) def forward(self, batch): input_ids = batch['input_ids'] labels = batch['labels'] return self.model(input_ids=input_ids, labels=labels) def training_step(self, batch, batch_idx): outputs = self(batch) loss = outputs[0] self.log('loss', loss) return loss def validation_step(self, batch, batch_idx): outputs = self(batch) loss = outputs[0].detach().cpu().numpy() return {'loss': loss} def validation_epoch_end(self, outputs): val_loss = np.mean([out['loss'] for out in outputs]) self.log('steps', self.global_step) 
self.log('val_loss', val_loss) def configure_optimizers(self): return AdamW(self.parameters(), lr=5e-5) class MLMDataModule(pl.LightningDataModule): def __init__(self, cfg): super().__init__() self.cfg = cfg def setup(self, stage=None): dataset_dir = Path(self.cfg.dataset_dir) train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True) valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True) self.train_dataset = MLMDataset(train_data) self.val_dataset = MLMDataset(valid_data) def train_dataloader(self): return DataLoader(self.train_dataset, **self.cfg.train_loader) def val_dataloader(self): return DataLoader(self.val_dataset, **self.cfg.val_loader) ================================================ FILE: src/pl_modules/policy_value.py ================================================ from pathlib import Path import numpy as np import pytorch_lightning as pl from torch.utils.data import DataLoader from transformers import AdamW from src.data.policy_value import PolicyValueDataset from src.model.bert import BertPolicyValue class PolicyValueModule(pl.LightningModule): def __init__(self, hparams): super().__init__() self.hparams = hparams self.model = BertPolicyValue(hparams['model_dir']) def forward(self, input_ids, labels=None): output = self.model(input_ids, labels) return output def training_step(self, batch, batch_idx): input_ids = batch.pop('input_ids') output = self(input_ids, batch) return {'loss': output['loss']} def validation_step(self, batch, batch_idx): input_ids = batch.pop('input_ids') output = self(input_ids, batch) for k, v in output.items(): output[k] = v.detach().cpu().numpy() return output def validation_epoch_end(self, outputs): val_loss = np.mean([out['loss'] for out in outputs]) val_loss_policy = np.mean([out['loss_policy'] for out in outputs]) val_loss_value = np.mean([out['loss_value'] for out in outputs]) self.log('val_loss', val_loss) self.log('val_loss_policy', val_loss_policy) self.log('val_loss_value', val_loss_value) def 
configure_optimizers(self): return AdamW(self.parameters(), lr=5e-5) class PolicyValueDataModule(pl.LightningDataModule): def __init__(self, cfg): super().__init__() self.cfg = cfg def setup(self, stage=None): dataset_dir = Path(self.cfg.dataset_dir) train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True) print('Load train data') valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True) print('Load val data') self.train_dataset = PolicyValueDataset(train_data) self.val_dataset = PolicyValueDataset(valid_data) def train_dataloader(self): return DataLoader(self.train_dataset, **self.cfg.train_loader) def val_dataloader(self): return DataLoader(self.val_dataset, **self.cfg.val_loader) ================================================ FILE: src/player/__init__.py ================================================ ================================================ FILE: src/player/base_player.py ================================================ import cshogi class BasePlayer: def __init__(self): self.board = cshogi.Board() def usi(self): print('id name bert_player') print('usiok') def usinewgame(self): pass def setoption(self, option): pass def isready(self): pass def position(self, moves): if moves[0] == 'startpos': self.board.reset() for move in moves[2:]: self.board.push_usi(move) elif moves[0] == 'sfen': self.board.set_sfen(' '.join(moves[1:])) # for debug print(self.board.sfen()) def go(self): pass def quit(self): pass ================================================ FILE: src/player/mcts_player.py ================================================ import time from argparse import ArgumentParser from pathlib import Path import cshogi import numpy as np import torch from src.features.common import get_seq_from_board from src.features.policy_value import get_move_label from src.pl_modules.policy_value import PolicyValueModule from src.player.base_player import BasePlayer from src.player.usi import usi from src.uct.uct_node import NOT_EXPANDED, NodeHash, 
UCT_HASH_SIZE, UctNode
from src.utils.misc import boltzmann


class MCTSPlayer(BasePlayer):
    """USI player choosing moves by PUCT Monte-Carlo tree search guided by a BERT policy/value network."""

    def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign_threshold=0.01, c_puct=1):
        super().__init__()
        self.ckpt_path = ckpt_path                # PolicyValueModule checkpoint, loaded lazily in isready()
        self.playout_halt = playout_halt          # playout budget per move
        self.temperature = temperature            # Boltzmann temperature applied to policy logits
        self.resign_threshold = resign_threshold  # resign when the best win rate drops below this
        self.c_puct = c_puct                      # exploration constant of the PUCT formula
        self.model = None
        self.node_hash = NodeHash()
        self.uct_nodes = [UctNode() for _ in range(UCT_HASH_SIZE)]
        self.current_n_idx = None                 # index of the current search root
        self.playout_count = 0

    def usi(self):
        print('id name BERT-MCTS')
        print('usiok')

    def isready(self):
        if self.model is None:
            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
            self.model.cuda()
            self.model.eval()
        self.node_hash.initialize()
        print('readyok')

    def go(self):
        if self.board.is_game_over():
            print('bestmove resign')
            return

        # Clear per-search statistics
        self.playout_count = 0
        # Drop hash entries that are unreachable from the current position
        self.node_hash.delete_old_hash(self.board, self.uct_nodes)

        # Search start time
        begin_time = time.time()

        # Expand the root node
        self.current_n_idx = self.expand_node()

        # With a single candidate move, answer immediately
        current_node = self.uct_nodes[self.current_n_idx]
        child_moves = current_node.child_moves
        child_num = len(child_moves)
        if child_num == 1:
            print('bestmove', cshogi.move_to_usi(child_moves[0]))
            return

        def get_bestmove_and_print_info():
            # Elapsed search time
            finish_time = time.time() - begin_time
            if self.board.move_number < 5:
                # Early in the game: sample from the root policy for opening variety
                selected_index = np.random.choice(np.arange(child_num), p=current_node.policy)
            else:
                # Otherwise pick the most-visited child
                selected_index = np.argmax(current_node.child_moves_count)
            # If the chosen child was never visited, fall back to its policy prior
            if current_node.child_moves_count[selected_index] == 0:
                best_wp = current_node.policy[selected_index]
            # Otherwise use its mean value (win rate)
            else:
                best_wp = current_node.child_value_sum[selected_index] / current_node.child_moves_count[selected_index]
            # Resign below the threshold
            if best_wp < self.resign_threshold:
                bestmove = 'resign'
            else:
                bestmove = cshogi.move_to_usi(child_moves[selected_index])
            # Convert the win rate to a centipawn-like evaluation score
            if best_wp == 1:
                cp = 30000
            elif best_wp == 0:
                cp = -30000
            else:
                cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)
            nps = int(current_node.move_count / finish_time)
            time_secs = int(finish_time * 1000)
            nodes = current_node.move_count
            hashfull = self.node_hash.get_usage_rate() * 100
            print(f'info score cp {cp} hashfull {hashfull:.2f} time {time_secs} nodes {nodes} nps {nps} pv {bestmove}')
            return bestmove

        # Repeat playouts; leave the loop when the budget is spent or the search is cut off
        while self.playout_count < self.playout_halt:
            self.playout_count += 1
            self.uct_search(self.current_n_idx)
            # Print search info every 10 playouts
            if (self.playout_count+1) % 10 == 0:
                get_bestmove_and_print_info()
            # Check whether the search can be stopped early
            if self.interruption_check() or not self.node_hash.enough_size:
                break
        bestmove = get_bestmove_and_print_info()
        print('bestmove', bestmove)

    def expand_node(self):
        """Find or create the UCT node for the current board position and return its index."""
        current_hash = self.board.zobrist_hash()
        current_turn = self.board.turn
        current_move_number = self.board.move_number
        n_idx = self.node_hash.find_same_hash_index(current_hash, current_turn, current_move_number)
        # If this position already has a node (transposition), reuse it
        if n_idx != UCT_HASH_SIZE:
            return n_idx
        # Otherwise claim an empty slot
        n_idx = self.node_hash.search_empty_index(current_hash, current_turn, current_move_number)
        # Initialise the node
        current_node = self.uct_nodes[n_idx]
        current_node.reset()
        # Enumerate candidate moves
        current_node.child_moves = [move for move in self.board.legal_moves]
        child_num = len(current_node.child_moves)
        current_node.child_n_indices = [NOT_EXPANDED for _ in range(child_num)]
        current_node.child_moves_count = np.zeros(child_num, dtype=np.int32)
        current_node.child_value_sum = np.zeros(child_num, dtype=np.float32)
        # Evaluate the node
        self.eval_node(n_idx)
        return n_idx

    def eval_node(self, n_idx):
        """Run the policy/value network for the current position and store the result on node n_idx."""
        current_node = self.uct_nodes[n_idx]
        child_moves = current_node.child_moves
        child_num = len(current_node.child_moves)
        if child_num == 0:
            # No legal move means the side to move is mated: value 0 (= loss)
            current_node.value = 0
            current_node.evaled = True
        else:
            # Policy and value for the current position
            seq = get_seq_from_board(self.board)
            input_ids = torch.tensor([seq], dtype=torch.long).cuda()
            with torch.no_grad():
                output = self.model(input_ids)
            value = output['value'].detach().cpu().numpy()[0]
            policy_logits = output['policy'].detach().cpu().numpy()[0]
            # Keep only the logits of legal moves
            legal_move_labels = []
            for move in child_moves:
                legal_move_labels.append(get_move_label(move, self.board.turn))
            # Boltzmann distribution over the legal logits
            policy = boltzmann(policy_logits[legal_move_labels], self.temperature)
            # Store the evaluation on the node
            current_node.policy = policy
            current_node.value = value
            current_node.evaled = True

    def uct_search(self, n_idx):
        """One playout from node n_idx; returns the result from the parent's point of view (1=win, 0=loss)."""
        current_node = self.uct_nodes[n_idx]
        child_moves = current_node.child_moves
        child_n_indices = current_node.child_n_indices
        child_num = len(child_moves)
        # Mate here is a loss, i.e. a win seen from the parent node
        if child_num == 0:
            return 1

        # Child with the highest PUCT score
        next_c_idx = self.select_max_ucb_child(n_idx)
        next_move = child_moves[next_c_idx]
        next_n_idx = child_n_indices[next_c_idx]
        # Play the selected move
        self.board.push(next_move)
        if next_n_idx == NOT_EXPANDED:
            # The selected child is not expanded yet: expand it
            # (node expansion also evaluates the node)
            next_n_idx = self.expand_node()
            child_n_indices[next_c_idx] = next_n_idx
            child_node = self.uct_nodes[next_n_idx]
            result = 1 - child_node.value
        else:
            # Already expanded: search one ply deeper
            result = self.uct_search(next_n_idx)

        # Backup
        # Reflect the playout result into the statistics
        current_node.move_count += 1
        current_node.child_value_sum[next_c_idx] += result
        current_node.child_moves_count[next_c_idx] += 1
        # Undo the move
        # NOTE(review): recent cshogi versions expose Board.pop() without an argument — confirm
        self.board.pop(next_move)
        return 1 - result

    def select_max_ucb_child(self, c_idx):
        """Index of the child of node c_idx maximising the PUCT upper confidence bound."""
        current_node = self.uct_nodes[c_idx]
        child_wins_count = current_node.child_value_sum
        child_moves_count = current_node.child_moves_count
        child_num = len(child_moves_count)
        # q: mean value per child; children with zero visits get 0.5
        q = np.divide(child_wins_count, child_moves_count,
                      out=np.repeat(0.5, child_num),
                      where=child_moves_count != 0)
        u = np.sqrt(current_node.move_count) / (1 + child_moves_count)
        ucb = q + self.c_puct * current_node.policy * u
        return np.argmax(ucb)

    # Decide whether the search can be stopped early
    def interruption_check(self):
        child_move_count = self.uct_nodes[self.current_n_idx].child_moves_count
        rest = self.playout_halt - self.playout_count
        # Visit counts of the most- and second-most-visited moves
        second, first = child_move_count[np.argpartition(child_move_count, -2)[-2:]]
        # Stop when even spending every remaining playout on the runner-up cannot overtake the leader
        if first - second > rest:
            return True
        else:
            return False


def argparse():
    parser = ArgumentParser()
    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
    args, _ = parser.parse_known_args()
    print('Command Line Args:')
    print(args)
    return args


def main(args):
    # Resolve the checkpoint path relative to the repository root
    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
    player = MCTSPlayer(ckpt_path)
    usi(player)


if __name__ == '__main__':
    args = argparse()
    main(args)
================================================ FILE: src/player/policy_player.py ================================================
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import cshogi
import torch
import torch.nn.functional as F

from src.features.common import get_seq_from_board
from src.features.policy_value import get_move_label
from src.pl_modules.policy_value import PolicyValueModule
from src.player.base_player import BasePlayer
from src.player.usi import usi
from src.utils.misc import greedy


class PolicyPlayer(BasePlayer):
    """USI player that plays the single move preferred by the policy head (no search)."""

    def __init__(self, ckpt_path):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.model = None  # loaded lazily in isready()

    def usi(self):
        print('id name BERT-Policy')
        print('usiok')

    def isready(self):
        if self.model is None:
            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model
            self.model.cuda()
            self.model.eval()
        print('readyok')

    def go(self):
        if self.board.is_game_over():
            print('bestmove resign')
            return

        seq = get_seq_from_board(self.board)
        input_ids = torch.tensor([seq], dtype=torch.long).cuda()
        with torch.no_grad():
            output = self.model(input_ids)
        policy = F.softmax(output['policy'], dim=1).detach().cpu().numpy()[0]

        # For every legal move
        legal_moves = []
        legal_policy = []
        for move in self.board.legal_moves:
            # Convert the move to its policy label
            label = get_move_label(move, self.board.turn)
            # Keep the move together with its probability
            legal_moves.append(move)
            legal_policy.append(policy[label])
        selected_index = 
greedy(legal_policy)
        bestmove = cshogi.move_to_usi(legal_moves[selected_index])
        best_wp = legal_policy[selected_index]
        # Convert the win rate to a centipawn-like evaluation score
        if best_wp == 1:
            cp = 30000
        elif best_wp == 0:
            cp = -30000
        else:
            cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)
        print(f'info score cp {cp} pv {bestmove}')
        print('bestmove', bestmove)


def argparse():
    parser = ArgumentParser()
    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')
    args, _ = parser.parse_known_args()
    print('Command Line Args:')
    print(args)
    return args


def main(args):
    # Resolve the checkpoint path relative to the repository root
    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()
    player = PolicyPlayer(ckpt_path)
    usi(player)


if __name__ == '__main__':
    args = argparse()
    main(args)
================================================ FILE: src/player/usi.py ================================================
from src.player.base_player import BasePlayer


def usi(player: BasePlayer):
    """Minimal USI protocol loop: dispatch stdin commands to the player until 'quit'."""
    while True:
        cmd_line = input()
        cmd = cmd_line.split(' ', 1)
        cmd = [c.rstrip() for c in cmd]

        if cmd[0] == 'usi':
            player.usi()
        elif cmd[0] == 'setoption':
            option = cmd[1].split(' ')
            player.setoption(option)
        elif cmd[0] == 'isready':
            player.isready()
        elif cmd[0] == 'usinewgame':
            player.usinewgame()
        elif cmd[0] == 'position':
            moves = cmd[1].split(' ')
            player.position(moves)
        elif cmd[0] == 'go':
            player.go()
        elif cmd[0] == 'quit':
            player.quit()
            break
================================================ FILE: src/uct/__init__.py ================================================ ================================================ FILE: src/uct/uct_node.py ================================================
# Upper bound on the number of nodes
UCT_HASH_SIZE = 4096

# Sentinel index for a child node that has not been expanded yet
NOT_EXPANDED = -1


# Fold the 64-bit Zobrist hash down to a table index in [0, UCT_HASH_SIZE)
def hash_to_index(zhash):
    return ((zhash & 0xffffffff) ^ ((zhash >> 32) & 0xffffffff)) & (UCT_HASH_SIZE - 1)


class NodeHashEntry:
    """One slot of the open-addressed node hash table."""

    def __init__(self):
        self.hash = 0      # Zobrist hash value of the position
        self.color = 0     # side to move
        self.moves = 0     # number of moves since the start of the game
        self.flag = False  # flag marking the slot as in use

    def reset(self):
self.hash = 0 self.color = 0 self.moves = 0 self.flag = False class NodeHash: def __init__(self): self.used = 0 self.enough_size = True self.node_hash = None def initialize(self): self.used = 0 self.enough_size = True if self.node_hash is None: self.node_hash = [NodeHashEntry() for _ in range(UCT_HASH_SIZE)] else: for i in range(UCT_HASH_SIZE): self.node_hash[i].reset() # 未使用のインデックスを探して返す def search_empty_index(self, zhash, color, moves): key = hash_to_index(zhash) i = key while True: if not self.node_hash[i].flag: self.node_hash[i].hash = zhash self.node_hash[i].color = color self.node_hash[i].moves = moves self.node_hash[i].flag = True self.used += 1 if self.get_usage_rate() > 0.9: self.enough_size = False return i i += 1 if i >= UCT_HASH_SIZE: i = 0 # もとに戻ってくる if i == key: return UCT_HASH_SIZE # ハッシュ値に対応するインデックスを返す def find_same_hash_index(self, zhash, color, moves): key = hash_to_index(zhash) i = key while True: # もろもろの属性があっていたらiを返す if self.node_hash[i].flag and self.node_hash[i].hash == zhash and self.node_hash[i].color == color and \ self.node_hash[i].moves == moves: return i else: return UCT_HASH_SIZE # 使用中のノードを残す def save_used_hash(self, board, uct_nodes, n_idx): self.node_hash[n_idx].flag = True self.used += 1 current_node = uct_nodes[n_idx] child_n_indices = current_node.child_n_indices child_moves = current_node.child_moves child_num = len(child_moves) for i in range(child_num): if child_n_indices[i] != NOT_EXPANDED and not self.node_hash[child_n_indices[i]].flag: board.push(child_moves[i]) self.save_used_hash(board, uct_nodes, child_n_indices[i]) board.pop(child_moves[i]) # 古いハッシュを削除 def delete_old_hash(self, board, uct_node): # 現在の局面をルートとする局面以外を削除する n_idx = self.find_same_hash_index(board.zobrist_hash(), board.turn, board.move_number) self.used = 0 for i in range(UCT_HASH_SIZE): self.node_hash[i].reset() if n_idx != UCT_HASH_SIZE: self.save_used_hash(board, uct_node, n_idx) self.enough_size = True def get_usage_rate(self): return self.used / 
UCT_HASH_SIZE class UctNode: def __init__(self): self.evaled = False # 評価済みフラグ self.move_count = 0 # ノードの訪問回数 self.value = 0 # ノードの価値ネットワークの評価(予測勝率) self.policy = None # 正規化した方策ネットワークの出力(子ノード分の長さを持つ) self.child_moves = None # 子ノードの指し手 self.child_n_indices = None # 子ノードのインデックス self.child_moves_count = None # 子ノードの訪問回数 (UCB用) self.child_value_sum = None # 子ノードのvalueの合計 (UCB用) def reset(self): self.evaled = False self.move_count = 0 self.value = 0 self.policy = None self.child_moves = None self.child_n_indices = None self.child_moves_count = None self.child_value_sum = None ================================================ FILE: src/utils/__init__.py ================================================ ================================================ FILE: src/utils/hcpe.py ================================================ import cshogi import numpy as np from tqdm import tqdm from src.features.policy_value import get_policy_value_label def get_data_from_hcpe(hcpes): data = [] for hcpe in tqdm(hcpes): data.append(get_policy_value_label(hcpe)) data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')]) return data def load_hcpes(hcpe_paths): hcpe_list = [] for hcpe_path in hcpe_paths: hcpe_list.append(np.fromfile(hcpe_path, dtype=cshogi.HuffmanCodedPosAndEval)) hcpes = np.concatenate(hcpe_list) return hcpes ================================================ FILE: src/utils/misc.py ================================================ import numpy as np def greedy(logits): return np.asarray(logits).argmax() def boltzmann(logits, temperature): logits /= temperature logits -= logits.max() probabilities = np.exp(logits) probabilities /= probabilities.sum() return probabilities ================================================ FILE: src/utils/sfen.py ================================================ from pathlib import Path import numpy as np from src.features.policy_value import get_moves_from_lines, get_policy_value_label_from_moves def 
get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):
    """Collect .sfen paths of even ("互角") opening positions reached in at most max_num_of_moves moves."""
    assert base_dir.name == 'ShogiGokakuKyokumen'
    sfen_paths = []
    # TODO: replace the two glob patterns with a regular expression
    for sfen_path in sorted(list(base_dir.glob('*/*/*.sfen')) + list(base_dir.glob('*/*/*/*.sfen'))):
        # Even positions with at most max_num_of_moves moves
        # (the move count is parsed from the file name, e.g. '...040手目...')
        if '互角' in sfen_path.stem and (int(sfen_path.stem.split('手目')[0][-3:]) <= max_num_of_moves):
            sfen_paths.append(sfen_path)
    return sfen_paths


def get_data_from_sfen(sfens):
    """Build (seq, label, value, result) training rows by sampling positions from each sfen game record."""
    # sfen records are sampled part-way through the game to create training data
    data = []
    for sfen in sfens:
        moves = get_moves_from_lines(sfen)
        # Use only games with at least 6 moves
        if len(moves) < 6:
            continue
        # e.g. a 30-move game yields 6 sampled positions
        for _ in range(max(1, len(moves) // 5)):
            dt = get_policy_value_label_from_moves(moves)
            data.append(dt)
    data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])
    return data


def load_sfen(sfen_path):
    """Read one sfen file into a list of stripped lines."""
    with open(sfen_path) as f:
        sfens = [l.rstrip() for l in f.readlines()]
    return sfens


def load_sfens(sfen_paths):
    """Concatenate the lines of every file in sfen_paths."""
    sfens = []
    for kifu_path in sfen_paths:
        sfens.extend(load_sfen(kifu_path))
    return sfens
================================================ FILE: src/utils/shogi.py ================================================
import numpy as np

# Move-direction constants
MOVE_DIRECTION = [
    UP, UP_LEFT, UP_RIGHT, LEFT, RIGHT, DOWN, DOWN_LEFT, DOWN_RIGHT,
    UP2_LEFT, UP2_RIGHT,
    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE,
    DOWN_PROMOTE, DOWN_LEFT_PROMOTE, DOWN_RIGHT_PROMOTE,
    UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
] = range(20)

# Promotion conversion table
MOVE_DIRECTION_PROMOTED = [
    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE,
    DOWN_PROMOTE, DOWN_LEFT_PROMOTE, DOWN_RIGHT_PROMOTE,
    UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE
]

# Number of move labels; the 7 covers the droppable piece types
MOVE_DIRECTION_LABEL_NUM = len(MOVE_DIRECTION) + 7

# Lookup tables for swapping first-player and second-player pieces
# (reverse_piece_fn below): 1..14 are own pieces, 17..30 the opponent's
bw_dict = {k: k + 16 for k in range(1, 15)}
wb_dict = {v: k for k, v in bw_dict.items()}
pieces_dict = {**bw_dict, **wb_dict, 0: 0}  # 0 is an empty square
pieces_list = list(pieces_dict.keys())
reverse_piece_fn = 
np.vectorize(pieces_dict.get)
================================================ FILE: tools/download_and_build_lesserkai.sh ================================================
wget http://shogidokoro.starfree.jp/download/LesserkaiSrc.zip -P ./work_dirs
unzip ./work_dirs/LesserkaiSrc.zip -d ./work_dirs/
rm ./work_dirs/LesserkaiSrc.zip
cd ./work_dirs/LesserkaiSrc/Lesserkai
make
================================================ FILE: tools/make_dataset.py ================================================
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split

from src.utils.hcpe import get_data_from_hcpe, load_hcpes
from src.utils.sfen import get_data_from_sfen, get_gokaku_sfen_paths, load_sfens


def main():
    """Build train/val .npy datasets from even-position sfen records and self-play hcpe files."""
    gokaku_dir = Path('./data/ShogiGokakuKyokumen/')
    hcpe_dir = Path('./data/hcpe/')
    dataset_base_dir = Path('./data/dataset/')
    dataset_base_dir.mkdir(exist_ok=True)

    # Even ("gokaku") opening positions, limited to 40 / 100 moves
    for num_of_moves in [40, 100]:
        sfen_paths = get_gokaku_sfen_paths(gokaku_dir, num_of_moves)
        sfens = load_sfens(sfen_paths)
        data = get_data_from_sfen(sfens)
        train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)
        dataset_dir = dataset_base_dir / f'gokaku_{num_of_moves:03d}'
        dataset_dir.mkdir()
        np.save(dataset_dir / 'train.npy', train_data)
        np.save(dataset_dir / 'val.npy', valid_data)

    # Self-play game records
    hcpe_paths = sorted(hcpe_dir.glob('selfplay-*'))
    hcpes = load_hcpes(hcpe_paths)
    data = get_data_from_hcpe(hcpes)
    train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)
    dataset_dir = dataset_base_dir / 'selfplay'
    dataset_dir.mkdir()
    np.save(dataset_dir / 'train.npy', train_data)
    np.save(dataset_dir / 'val.npy', valid_data)


if __name__ == '__main__':
    main()
================================================ FILE: tools/pl_to_transformers.py ================================================
from argparse import ArgumentParser
from pathlib import Path

import torch

from src.model.bert import config


def argparse():
    parser = 
ArgumentParser(description='Convert pl checkpoint to transformers format')
    parser.add_argument('ckpt_path', type=str)
    args, _ = parser.parse_known_args()
    return args


def main(args):
    """Extract the bare model weights from a Lightning checkpoint into transformers format."""
    ckpt_path = Path(args.ckpt_path)
    ckpt_dir = ckpt_path.parent
    ckpt = torch.load(ckpt_path)
    state_dict = ckpt['state_dict']
    # Drop the first two dotted components Lightning prepends to every key
    state_dict = {'.'.join(k.split('.')[2:]): v for k, v in state_dict.items()}
    # pytorch_model.bin and config.json must live in the same directory
    state_dict_path = ckpt_dir / f'pytorch_model.bin'
    torch.save(state_dict, state_dict_path)
    config.to_json_file(ckpt_dir / 'config.json')


if __name__ == '__main__':
    args = argparse()
    main(args)
================================================ FILE: tools/test_engine.py ================================================
from argparse import ArgumentParser
from pathlib import Path

from cshogi import cli


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('engine_path')
    return parser.parse_args()


def main(args):
    """Play the given USI engine against the Lesserkai benchmark engine."""
    benchmark = str(Path('./work_dirs/LesserkaiSrc/Lesserkai/Lesserkai').resolve())
    engine = str((Path(args.engine_path)).resolve())
    print('benchmark: ', benchmark)
    print('engine: ', engine)
    cli.main(benchmark, engine)


if __name__ == '__main__':
    args = parse_args()
    main(args)
================================================ FILE: tools/train.py ================================================
from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from src.pl_modules import get_pl_modules


def argparse():
    parser = ArgumentParser()
    parser.add_argument('cfg', type=str)
    parser.add_argument('--log_dir', type=str, default='./work_dirs')
    parser.add_argument('--ckpt_path', type=str, default='')
    parser = pl.Trainer.add_argparse_args(parser)
    args, _ = parser.parse_known_args()
    return args


def main(args):
    cfg = OmegaConf.load(args.cfg)
    pl.seed_everything(cfg.seed)
    # 
Reflect cfg.train_params onto the parsed Trainer args
    for k, v in cfg.train_params.items():
        setattr(args, k, v)
    # Disable default checkpoint callback
    args.checkpoint_callback = False
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.logger = pl_loggers.TensorBoardLogger(save_dir=args.log_dir,
                                                  name=Path(args.cfg).stem,
                                                  default_hp_metric=False)
    # Keep only the best checkpoint by val_loss, plus last.ckpt
    trainer.callbacks.append(ModelCheckpoint(filename='{step:07d}-{val_loss:.2f}',
                                             monitor='val_loss',
                                             save_top_k=1,
                                             save_last=True))
    model, data = get_pl_modules(cfg)
    if args.ckpt_path:
        # Resume model weights from an earlier checkpoint
        model = model.load_from_checkpoint(args.ckpt_path)
    trainer.fit(model, datamodule=data)


if __name__ == '__main__':
    args = argparse()
    main(args)