[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# Project specific folders\nlightning_logs/\ntf_logs/\nremote_logs/\nlocal_logs/\n*_logs/\ncheckpoints/\ndataset/\nlocal_misc/\nnotebooks/data/\n\n# LSF logfiles\nlsf.*\n\n# IDEA folders\n.idea/\n\n/work_dirs/*\n/data/*\n!.gitkeep\n"
  },
  {
    "path": "Makefile",
    "content": "PROJECT ?= bert-mcts\nDATADIR ?= ${PWD}/data\nWORKSPACE ?= /workspace/$(PROJECT)\nDOCKER_IMAGE ?= ${PROJECT}:latest\n\nSHMSIZE ?= 100G\nDOCKER_OPTS := \\\n\t\t\t--name ${PROJECT} \\\n\t\t\t--rm -it \\\n\t\t\t--shm-size=${SHMSIZE} \\\n\t\t\t-v ${PWD}:${WORKSPACE} \\\n\t\t\t-v ${DATADIR}:${WORKSPACE}/data \\\n\t\t\t-v ${LOG_DIR}:${WORKSPACE}/work_dirs/logs \\\n      -w ${WORKSPACE} \\\n\t\t\t--ipc=host \\\n\t\t\t--network=host \\\n\t\t\t--gpus all\n\ndocker-build:\n\tdocker build -f docker/Dockerfile -t ${DOCKER_IMAGE} .\n\ndocker-start-interactive: docker-build\n\tdocker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash\n\ndocker-start-jupyter: docker-build\n\tdocker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \\\n\t\tbash -c \"jupyter lab --port=8888 --ip=0.0.0.0 --allow-root --no-browser\"\n\ndocker-run: docker-build\n\tdocker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \\\n\t\tbash -c \"${COMMAND}\"\n"
  },
  {
    "path": "README.md",
    "content": "# BERT-MCTS-YOUTUBE\n\nYouTubeにてヨビノリたくみさんと対戦した将棋ソフトです。\n自然言語モデルであるBERTとモンテカルロ木探索(MCTS)の組み合わせで出来ています。  \nすべてpythonで書いてあるため、探索の速度は遅いです。  \n\nBERT以外の大部分は『将棋AIで学ぶディープラーニング』を参考に書いています。\n- [書籍(amazon)](https://www.amazon.co.jp/dp/B07B7JJ929)\n- [github](https://github.com/TadaoYamaoka/python-dlshogi)\n\n## 環境\n\n### Colab\n\nテストするだけなら[google colab](https://colab.research.google.com/drive/10KAuLlNe6FKZBp_iE2bQJPNhoY2WeACx?usp=sharing) が簡単です。\n\n以下はローカルで試す場合。CPUだと遅いのでCUDA環境が望ましいです。  \n重みファイルは[ここ](https://drive.google.com/drive/folders/1N-Np2NmNLtLGS9gjnreBkYdTxrH1EHFw?usp=sharing) にアップしてあり、\nたくみさんと戦った重みファイルがyoutube_version.ckpt、追加で数日間学習させた重みファイルがlatest.ckptになります。  \nダウンロード先のパスはengine/***_player.sh内で指定してください。\nデフォルトではwork_dirs以下にダウンロードすることを想定しています。\n\n### Docker\n\ncuda10.2以上のnvidia-dockerが整っているなら次のコマンドで環境に入れます。\n```bash\n$ make docker-start-interactive\n```\n\n### Ubuntu18.04\n\ncuda10.2でanacondaが入っていれば次のコマンドで仮想環境を作れます。\n```bash\n$ conda env create -f env_name.yml\n$ conda activate bert-mcts-youtube\n$ python setup.py develop\n```\n\n### Windows10\n\n未検証\n\n## 将棋エンジンのテスト\n\nエンジンはengineディレクトリに用意しています。これらはShogiGUIなどから呼び出すことができます。\n- policy_player.shはBERTの出力する方策のみを頼りに指すモデル(弱い)\n- mcts_player.shはBERTの出力をもとにMCTSで探索するモデル\n\n## 学習\n\n学習には互角局面集とGCTの自己対戦棋譜を用いました。  \nモデルはMasked Language Modelで事前学習してから、Policy Value Networkの学習という手順を踏みます。  \nただし、将棋は良質な教師データが大量にあるため事前学習の効果はあまりない気がします。\n\n### データの準備\n\n互角局面集のダウンロード\n```bash\n$ cd data\n$ git clone https://github.com/tttak/ShogiGokakuKyokumen.git\n```\n\nGCTの自己対戦棋譜\n```bash\n$ cd data\n$ mkdir hcpe\n```\n\nGCTの自己対戦棋譜は開発者の加納さんが[リンク](https://drive.google.com/drive/folders/14FaqqIHRctTQIY6hScCFXWQQZ_pSU3-F)\nに公開してくださっていまし。\nここからselfplay-***となっているファイルをいくつかdata/hcpe以下にダウンロードしてください。  \nサイズが大きいので一個でも十分な量あります。\n\nこれらを準備できたら以下のコマンドでデータセットを作ります。\n\n```bash\n$ python tools/make_dataset.py\n```\n\n### Masked Language Modelの学習\n\n```bash\n$ python tools/train.py configs/mlm_base.yaml\n```\n\n### 重みファイルの変換\n\nMasked Language Modelのチェックポイントをtransformers形式に変換しておきます。\nこれによって転移学習のコードが多少書きやすくなります。\n\n```bash\n$ python tools/pl_to_transformers.py work_dirs/mlm_base/version_0/checkpoints/last.ckpt\n```\n\n### Policy Value Modelの学習\n\n最後にこれらを使ってPolicy Valueを学習させます。\n\n```bash\n$ python tools/train.py configs/policy_value.yaml\n```"
  },
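  {
    "path": "examples/policy_quickstart.py",
    "content": "# Illustrative sketch, not part of the original repo: scores the legal moves\n# of the start position with the policy head to show how the pieces of this\n# repository fit together. Assumes a trained checkpoint has been downloaded to\n# ./work_dirs/youtube_version.ckpt as described in the README; this file path\n# and name are hypothetical.\nimport cshogi\nimport torch\nimport torch.nn.functional as F\n\nfrom src.features.common import get_seq_from_board\nfrom src.features.policy_value import get_move_label\nfrom src.pl_modules.policy_value import PolicyValueModule\n\nboard = cshogi.Board()  # start position\n# map_location='cpu' so the sketch also runs without a GPU\nmodel = PolicyValueModule.load_from_checkpoint('./work_dirs/youtube_version.ckpt', map_location='cpu').model\nmodel.eval()\n\nseq = get_seq_from_board(board)\ninput_ids = torch.tensor([seq], dtype=torch.long)\nwith torch.no_grad():\n    output = model(input_ids)\n    policy = F.softmax(output['policy'], dim=1)[0]\n\n# print the five legal moves the policy considers most likely\nmoves = list(board.legal_moves)\nprobs = [policy[get_move_label(m, board.turn)].item() for m in moves]\nfor p, m in sorted(zip(probs, moves), reverse=True)[:5]:\n    print(cshogi.move_to_usi(m), f'{p:.3f}')\n"
  },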
  {
    "path": "configs/.gitkeep",
    "content": ""
  },
  {
    "path": "configs/mlm_base.yaml",
    "content": "model_type: 'MLM'\n\nseed: 42\ndataset_dir: './data/dataset/gokaku_100'\nmodel_dir:\n\ntrain_loader:\n  batch_size: 64\n  shuffle: True\n  num_workers: 8\n  pin_memory: False\n  drop_last: True\n\nval_loader:\n  batch_size: 64\n  shuffle: False\n  num_workers: 8\n  pin_memory: False\n  drop_last: False\n\ntrain_params:\n  max_epochs: 5\n  # validationおよびcheckpointの間隔step数\n  val_check_interval: 3000\n  # 環境に応じて，適宜変更\n  gpus: [0]"
  },
  {
    "path": "configs/policy_value_base.yaml",
    "content": "model_type: 'PolicyValue'\n\nseed: 42\ndataset_dir: './data/dataset/selfplay'\nmodel_dir: './work_dirs/mlm_base/version_0/checkpoints'\n\ntrain_loader:\n  batch_size: 128\n  shuffle: True\n  num_workers: 0\n  pin_memory: True\n  drop_last: True\n\nval_loader:\n  batch_size: 128\n  shuffle: False\n  num_workers: 0\n  pin_memory: True\n  drop_last: False\n\ntrain_params:\n  max_epochs: 1\n  # validationおよびcheckpointの間隔step数\n  val_check_interval: 30000\n  limit_val_batches: 0.1\n  # 環境に応じて，適宜変更\n  gpus: [0]\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime\nENV DEBIAN_FRONTEND=noninteractive\nADD requirements.txt /tmp\nRUN pip install -r /tmp/requirements.txt\n\nADD docker/entrypoint.sh /tmp\nENTRYPOINT [\"bash\", \"/tmp/entrypoint.sh\"]\n"
  },
  {
    "path": "docker/entrypoint.sh",
    "content": "python setup.py develop\nexec \"$@\"\n"
  },
  {
    "path": "engine/mcts_player.sh",
    "content": "#!/bin/sh\npython -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt\n"
  },
  {
    "path": "engine/policy_player.sh",
    "content": "#!/bin/sh\npython -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt\n"
  },
  {
    "path": "env_name.yml",
    "content": "name: bert-mcts-youtube\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - alabaster=0.7.12=py37_0\n  - anaconda=2020.11=py37_0\n  - anaconda-client=1.7.2=py37_0\n  - anaconda-project=0.8.4=py_0\n  - argh=0.26.2=py37_0\n  - argon2-cffi=20.1.0=py37h7b6447c_1\n  - asn1crypto=1.4.0=py_0\n  - astroid=2.4.2=py37_0\n  - astropy=4.0.2=py37h7b6447c_0\n  - async_generator=1.10=py37h28b3542_0\n  - atomicwrites=1.4.0=py_0\n  - attrs=20.3.0=pyhd3eb1b0_0\n  - autopep8=1.5.4=py_0\n  - babel=2.8.1=pyhd3eb1b0_0\n  - backcall=0.2.0=py_0\n  - backports=1.0=py_2\n  - backports.shutil_get_terminal_size=1.0.0=py37_2\n  - beautifulsoup4=4.9.3=pyhb0f4dca_0\n  - bitarray=1.6.1=py37h27cfd23_0\n  - bkcharts=0.2=py37_0\n  - blas=1.0=mkl\n  - bleach=3.2.1=py_0\n  - blosc=1.20.1=hd408876_0\n  - bokeh=2.2.3=py37_0\n  - boto=2.49.0=py37_0\n  - bottleneck=1.3.2=py37heb32a55_1\n  - brotlipy=0.7.0=py37h7b6447c_1000\n  - bzip2=1.0.8=h7b6447c_0\n  - ca-certificates=2020.10.14=0\n  - cairo=1.14.12=h8948797_3\n  - certifi=2020.6.20=pyhd3eb1b0_3\n  - cffi=1.14.3=py37he30daa8_0\n  - chardet=3.0.4=py37_1003\n  - click=7.1.2=py_0\n  - cloudpickle=1.6.0=py_0\n  - clyent=1.2.2=py37_1\n  - colorama=0.4.4=py_0\n  - contextlib2=0.6.0.post1=py_0\n  - cryptography=3.1.1=py37h1ba5d50_0\n  - cudatoolkit=10.2.89=hfd86e86_1\n  - curl=7.71.1=hbc83047_1\n  - cycler=0.10.0=py37_0\n  - cython=0.29.21=py37he6710b0_0\n  - cytoolz=0.11.0=py37h7b6447c_0\n  - dask=2.30.0=py_0\n  - dask-core=2.30.0=py_0\n  - dbus=1.13.18=hb2f20db_0\n  - decorator=4.4.2=py_0\n  - defusedxml=0.6.0=py_0\n  - diff-match-patch=20200713=py_0\n  - distributed=2.30.1=py37h06a4308_0\n  - docutils=0.16=py37_1\n  - entrypoints=0.3=py37_0\n  - et_xmlfile=1.0.1=py_1001\n  - expat=2.2.10=he6710b0_2\n  - fastcache=1.1.0=py37h7b6447c_0\n  - ffmpeg=4.3=hf484d3e_0\n  - filelock=3.0.12=py_0\n  - flake8=3.8.4=py_0\n  - flask=1.1.2=py_0\n  - fontconfig=2.13.0=h9420a91_0\n  - freetype=2.10.4=h5ab3b9f_0\n  - fribidi=1.0.10=h7b6447c_0\n  - fsspec=0.8.3=py_0\n  - future=0.18.2=py37_1\n  - get_terminal_size=1.0.0=haa9412d_0\n  - gevent=20.9.0=py37h7b6447c_0\n  - glib=2.66.1=h92f7085_0\n  - glob2=0.7=py_0\n  - gmp=6.1.2=h6c8ec71_1\n  - gmpy2=2.0.8=py37h10f8cd9_2\n  - gnutls=3.6.15=he1e5248_0\n  - graphite2=1.3.14=h23475e2_0\n  - greenlet=0.4.17=py37h7b6447c_0\n  - gst-plugins-base=1.14.0=hbbd80ab_1\n  - gstreamer=1.14.0=hb31296c_0\n  - h5py=2.10.0=py37h7918eee_0\n  - harfbuzz=2.4.0=hca77d97_1\n  - hdf5=1.10.4=hb1b8bf9_0\n  - heapdict=1.0.1=py_0\n  - html5lib=1.1=py_0\n  - icu=58.2=he6710b0_3\n  - idna=2.10=py_0\n  - imageio=2.9.0=py_0\n  - imagesize=1.2.0=py_0\n  - importlib-metadata=2.0.0=py_1\n  - importlib_metadata=2.0.0=1\n  - iniconfig=1.1.1=py_0\n  - intel-openmp=2020.2=254\n  - intervaltree=3.1.0=py_0\n  - ipykernel=5.3.4=py37h5ca1d4c_0\n  - ipython=7.19.0=py37hb070fc8_0\n  - ipython_genutils=0.2.0=py37_0\n  - ipywidgets=7.5.1=py_1\n  - isort=5.6.4=py_0\n  - itsdangerous=1.1.0=py37_0\n  - jbig=2.1=hdba287a_0\n  - jdcal=1.4.1=py_0\n  - jedi=0.17.1=py37_0\n  - jeepney=0.5.0=pyhd3eb1b0_0\n  - jinja2=2.11.2=py_0\n  - joblib=0.17.0=py_0\n  - jpeg=9b=h024ee3a_2\n  - json5=0.9.5=py_0\n  - jsonschema=3.2.0=py_2\n  - jupyter=1.0.0=py37_7\n  - jupyter_client=6.1.7=py_0\n  - jupyter_console=6.2.0=py_0\n  - jupyter_core=4.6.3=py37_0\n  - jupyterlab=2.2.6=py_0\n  - jupyterlab_pygments=0.1.2=py_0\n  - jupyterlab_server=1.2.0=py_0\n  - keyring=21.4.0=py37_1\n  - kiwisolver=1.3.0=py37h2531618_0\n  - krb5=1.18.2=h173b8e3_0\n  - lame=3.100=h7b6447c_0\n  - 
lazy-object-proxy=1.4.3=py37h7b6447c_0\n  - lcms2=2.11=h396b838_0\n  - ld_impl_linux-64=2.33.1=h53a641e_7\n  - libarchive=3.4.2=h62408e4_0\n  - libcurl=7.71.1=h20c2e04_1\n  - libedit=3.1.20191231=h14c3975_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.1.0=hdf63c60_0\n  - libgfortran-ng=7.3.0=hdf63c60_0\n  - libiconv=1.15=h63c8f33_5\n  - libidn2=2.3.0=h27cfd23_0\n  - liblief=0.10.1=he6710b0_0\n  - libllvm10=10.0.1=hbcb73fb_5\n  - libpng=1.6.37=hbc83047_0\n  - libsodium=1.0.18=h7b6447c_0\n  - libspatialindex=1.9.3=he6710b0_0\n  - libssh2=1.9.0=h1ba5d50_1\n  - libstdcxx-ng=9.1.0=hdf63c60_0\n  - libtasn1=4.16.0=h27cfd23_0\n  - libtiff=4.1.0=h2733197_1\n  - libtool=2.4.6=h7b6447c_1005\n  - libunistring=0.9.10=h27cfd23_0\n  - libuuid=1.0.3=h1bed415_2\n  - libuv=1.40.0=h7b6447c_0\n  - libxcb=1.14=h7b6447c_0\n  - libxml2=2.9.10=hb55368b_3\n  - libxslt=1.1.34=hc22bd24_0\n  - llvmlite=0.34.0=py37h269e1b5_4\n  - locket=0.2.0=py37_1\n  - lxml=4.6.1=py37hefd8a0e_0\n  - lz4-c=1.9.2=heb0550a_3\n  - lzo=2.10=h7b6447c_2\n  - markupsafe=1.1.1=py37h14c3975_1\n  - matplotlib=3.3.2=0\n  - matplotlib-base=3.3.2=py37h817c723_0\n  - mccabe=0.6.1=py37_1\n  - mistune=0.8.4=py37h14c3975_1001\n  - mkl=2020.2=256\n  - mkl-service=2.3.0=py37he904b0f_0\n  - mkl_fft=1.2.0=py37h23d657b_0\n  - mkl_random=1.1.1=py37h0573a6f_0\n  - mock=4.0.2=py_0\n  - more-itertools=8.6.0=pyhd3eb1b0_0\n  - mpc=1.1.0=h10f8cd9_1\n  - mpfr=4.0.2=hb69a4c5_1\n  - mpmath=1.1.0=py37_0\n  - msgpack-python=1.0.0=py37hfd86e86_1\n  - multipledispatch=0.6.0=py37_0\n  - nbclient=0.5.1=py_0\n  - nbconvert=6.0.7=py37_0\n  - nbformat=5.0.8=py_0\n  - ncurses=6.2=he6710b0_1\n  - nest-asyncio=1.4.2=pyhd3eb1b0_0\n  - nettle=3.7.2=hbbd107a_1\n  - networkx=2.5=py_0\n  - ninja=1.10.2=py37hff7bd54_0\n  - nltk=3.5=py_0\n  - nose=1.3.7=py37_1004\n  - notebook=6.1.4=py37_0\n  - numba=0.51.2=py37h04863e7_1\n  - numexpr=2.7.1=py37h423224d_0\n  - numpy=1.19.2=py37h54aff64_0\n  - numpy-base=1.19.2=py37hfa32c7d_0\n  - numpydoc=1.1.0=pyhd3eb1b0_1\n  - olefile=0.46=py37_0\n  - openh264=2.1.0=hd408876_0\n  - openpyxl=3.0.5=py_0\n  - openssl=1.1.1h=h7b6447c_0\n  - packaging=20.4=py_0\n  - pandas=1.1.3=py37he6710b0_0\n  - pandoc=2.11=hb0f4dca_0\n  - pandocfilters=1.4.3=py37h06a4308_1\n  - pango=1.45.3=hd140c19_0\n  - parso=0.7.0=py_0\n  - partd=1.1.0=py_0\n  - patchelf=0.12=he6710b0_0\n  - path=15.0.0=py37_0\n  - path.py=12.5.0=0\n  - pathlib2=2.3.5=py37_1\n  - pathtools=0.1.2=py_1\n  - patsy=0.5.1=py37_0\n  - pcre=8.44=he6710b0_0\n  - pep8=1.7.1=py37_0\n  - pexpect=4.8.0=py37_1\n  - pickleshare=0.7.5=py37_1001\n  - pillow=8.0.1=py37he98fc37_0\n  - pip=20.2.4=py37h06a4308_0\n  - pixman=0.40.0=h7b6447c_0\n  - pkginfo=1.6.1=py37h06a4308_0\n  - pluggy=0.13.1=py37_0\n  - ply=3.11=py37_0\n  - prometheus_client=0.8.0=py_0\n  - prompt-toolkit=3.0.8=py_0\n  - prompt_toolkit=3.0.8=0\n  - psutil=5.7.2=py37h7b6447c_0\n  - ptyprocess=0.6.0=py37_0\n  - py=1.9.0=py_0\n  - py-lief=0.10.1=py37h403a769_0\n  - pycodestyle=2.6.0=py_0\n  - pycosat=0.6.3=py37h7b6447c_0\n  - pycparser=2.20=py_2\n  - pycrypto=2.6.1=py37h7b6447c_10\n  - pycurl=7.43.0.6=py37h1ba5d50_0\n  - pydocstyle=5.1.1=py_0\n  - pyflakes=2.2.0=py_0\n  - pygments=2.7.2=pyhd3eb1b0_0\n  - pylint=2.6.0=py37_0\n  - pyodbc=4.0.30=py37he6710b0_0\n  - pyopenssl=19.1.0=py_1\n  - pyparsing=2.4.7=py_0\n  - pyqt=5.9.2=py37h05f1152_2\n  - pyrsistent=0.17.3=py37h7b6447c_0\n  - pysocks=1.7.1=py37_1\n  - pytables=3.6.1=py37h71ec239_0\n  - pytest=6.1.1=py37_0\n  - python=3.7.9=h7579374_0\n  - python-dateutil=2.8.1=py_0\n  - 
python-jsonrpc-server=0.4.0=py_0\n  - python-language-server=0.35.1=py_0\n  - python-libarchive-c=2.9=py_0\n  - pytorch=1.8.1=py3.7_cuda10.2_cudnn7.6.5_0\n  - pytz=2020.1=py_0\n  - pywavelets=1.1.1=py37h7b6447c_2\n  - pyxdg=0.27=pyhd3eb1b0_0\n  - pyyaml=5.3.1=py37h7b6447c_1\n  - pyzmq=19.0.2=py37he6710b0_1\n  - qdarkstyle=2.8.1=py_0\n  - qt=5.9.7=h5867ecd_1\n  - qtawesome=1.0.1=py_0\n  - qtconsole=4.7.7=py_0\n  - qtpy=1.9.0=py_0\n  - readline=8.0=h7b6447c_0\n  - regex=2020.10.15=py37h7b6447c_0\n  - requests=2.24.0=py_0\n  - ripgrep=12.1.1=0\n  - rope=0.18.0=py_0\n  - rtree=0.9.4=py37_1\n  - ruamel_yaml=0.15.87=py37h7b6447c_1\n  - scikit-image=0.17.2=py37hdf5156a_0\n  - scikit-learn=0.23.2=py37h0573a6f_0\n  - scipy=1.5.2=py37h0b6359f_0\n  - seaborn=0.11.0=py_0\n  - secretstorage=3.1.2=py37_1\n  - send2trash=1.5.0=py37_0\n  - setuptools=50.3.1=py37h06a4308_1\n  - simplegeneric=0.8.1=py37_2\n  - singledispatch=3.4.0.3=py_1001\n  - sip=4.19.8=py37hf484d3e_0\n  - six=1.15.0=py37h06a4308_0\n  - snowballstemmer=2.0.0=py_0\n  - sortedcollections=1.2.1=py_0\n  - sortedcontainers=2.2.2=py_0\n  - soupsieve=2.0.1=py_0\n  - sphinx=3.2.1=py_0\n  - sphinxcontrib=1.0=py37_1\n  - sphinxcontrib-applehelp=1.0.2=py_0\n  - sphinxcontrib-devhelp=1.0.2=py_0\n  - sphinxcontrib-htmlhelp=1.0.3=py_0\n  - sphinxcontrib-jsmath=1.0.1=py_0\n  - sphinxcontrib-qthelp=1.0.3=py_0\n  - sphinxcontrib-serializinghtml=1.1.4=py_0\n  - sphinxcontrib-websupport=1.2.4=py_0\n  - spyder=4.1.5=py37_0\n  - spyder-kernels=1.9.4=py37_0\n  - sqlalchemy=1.3.20=py37h27cfd23_0\n  - sqlite=3.33.0=h62c20be_0\n  - statsmodels=0.12.0=py37h7b6447c_0\n  - sympy=1.6.2=py37h06a4308_1\n  - tbb=2020.3=hfd86e86_0\n  - tblib=1.7.0=py_0\n  - terminado=0.9.1=py37_0\n  - testpath=0.4.4=py_0\n  - threadpoolctl=2.1.0=pyh5ca1d4c_0\n  - tifffile=2020.10.1=py37hdd07704_2\n  - tk=8.6.10=hbc83047_0\n  - toml=0.10.1=py_0\n  - toolz=0.11.1=py_0\n  - torchvision=0.9.1=py37_cu102\n  - tornado=6.0.4=py37h7b6447c_1\n  - tqdm=4.50.2=py_0\n  - traitlets=5.0.5=py_0\n  - typed-ast=1.4.1=py37h7b6447c_0\n  - typing_extensions=3.7.4.3=py_0\n  - ujson=4.0.1=py37he6710b0_0\n  - unicodecsv=0.14.1=py37_0\n  - unixodbc=2.3.9=h7b6447c_0\n  - urllib3=1.25.11=py_0\n  - watchdog=0.10.3=py37_0\n  - wcwidth=0.2.5=py_0\n  - webencodings=0.5.1=py37_1\n  - werkzeug=1.0.1=py_0\n  - wheel=0.35.1=py_0\n  - widgetsnbextension=3.5.1=py37_0\n  - wrapt=1.11.2=py37h7b6447c_0\n  - wurlitzer=2.0.1=py37_0\n  - xlrd=1.2.0=py37_0\n  - xlsxwriter=1.3.7=py_0\n  - xlwt=1.3.0=py37_0\n  - xz=5.2.5=h7b6447c_0\n  - yaml=0.2.5=h7b6447c_0\n  - yapf=0.30.0=py_0\n  - zeromq=4.3.3=he6710b0_3\n  - zict=2.0.0=py_0\n  - zipp=3.4.0=pyhd3eb1b0_0\n  - zlib=1.2.11=h7b6447c_3\n  - zope=1.0=py37_1\n  - zope.event=4.5.0=py37_0\n  - zope.interface=5.1.2=py37h7b6447c_0\n  - zstd=1.4.5=h9ceee32_0\n  - pip:\n    - absl-py==0.12.0\n    - aiohttp==3.7.4.post0\n    - async-timeout==3.0.1\n    - cachetools==4.2.1\n    - cshogi==0.2.4\n    - google-auth==1.28.1\n    - google-auth-oauthlib==0.4.4\n    - grpcio==1.37.0\n    - markdown==3.3.4\n    - multidict==5.1.0\n    - oauthlib==3.1.0\n    - omegaconf==2.0.6\n    - protobuf==3.15.8\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pytorch-lightning==1.2.7\n    - requests-oauthlib==1.3.0\n    - rsa==4.7.2\n    - sacremoses==0.0.44\n    - tensorboard==2.4.1\n    - tensorboard-plugin-wit==1.8.0\n    - tokenizers==0.10.2\n    - torchmetrics==0.2.0\n    - transformers==4.5.1\n    - yarl==1.6.3\nprefix: /home/charmer/.pyenv/versions/anaconda3-2020.07/envs/bert-mcts\n"
  },
  {
    "path": "requirements.txt",
    "content": "--find-links https://download.pytorch.org/whl/torch_stable.html\ntorch>=1.6\npytorch-lightning==1.2.7\ntransformers>=4.5\ncshogi\nomegaconf"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\nfrom torch.utils.cpp_extension import BuildExtension\n\nif __name__ == '__main__':\n    setup(\n        name='bert-mcts',\n        description='BERT-driven Shogi AI',\n        author='Hiroki Taniai',\n        author_email='charmer.popopo@gmail.com',\n        packages=find_packages(include=['src']),\n        cmdclass={'build_ext': BuildExtension},\n    )\n"
  },
  {
    "path": "src/data/__init__.py",
    "content": "\n"
  },
  {
    "path": "src/data/mlm.py",
    "content": "import numpy as np\r\nimport torch\r\nfrom torch.utils.data import Dataset\r\n\r\nfrom src.utils.shogi import pieces_list\r\n\r\n\r\n# Masked Language Model\r\nclass MLMDataset(Dataset):\r\n    def __init__(self, data):\r\n        self.seqs = data['seq']\r\n        self.mask_token_id = 32  # 32は駒が割り振られていないid\r\n        assert self.mask_token_id not in np.array(pieces_list)\r\n\r\n    def __len__(self):\r\n        return len(self.seqs)\r\n\r\n    def __getitem__(self, idx):\r\n        inputs = np.array(self.seqs[idx])\r\n        labels = inputs.copy()\r\n\r\n        # 予想対象\r\n        masked_indices = np.random.random(labels.shape) < 0.15\r\n        labels[~masked_indices] = -100\r\n\r\n        # 80%はマスクトークンに\r\n        indices_replaced = (np.random.random(labels.shape) < 0.8) & masked_indices\r\n        inputs[indices_replaced] = self.mask_token_id\r\n\r\n        # 10%はランダムに置き換え\r\n        indices_random = (np.random.random(labels.shape) < 0.5) & masked_indices & ~indices_replaced\r\n        random_words = np.random.choice(pieces_list, labels.shape)\r\n        inputs[indices_random] = random_words[indices_random]\r\n\r\n        # 残り10%はそのままのものが残る\r\n\r\n        ret_dict = {'input_ids': torch.tensor(inputs, dtype=torch.long),\r\n                    'labels': torch.tensor(labels, dtype=torch.long)}\r\n        return ret_dict\r\n"
  },
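  {
    "path": "examples/mlm_masking_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: empirically checks the\n# BERT-style 15% / 80-10-10 masking implemented in MLMDataset. The file path\n# and the fake board data are hypothetical.\nimport numpy as np\n\nfrom src.data.mlm import MLMDataset\nfrom src.utils.shogi import pieces_list\n\nrng = np.random.default_rng(0)\n# 1000 fake board sequences of length 95 (81 squares + 2 x 7 hand pieces)\ndata = {'seq': rng.choice(pieces_list, size=(1000, 95)).tolist()}\nds = MLMDataset(data)\n\nn_target = n_masked = n_total = 0\nfor i in range(len(ds)):\n    item = ds[i]\n    inputs, labels = item['input_ids'].numpy(), item['labels'].numpy()\n    n_total += labels.size\n    n_target += (labels != -100).sum()  # tokens selected as prediction targets\n    n_masked += (inputs == ds.mask_token_id).sum()  # tokens replaced by the mask token\n\nprint('selected for prediction:', n_target / n_total)  # ~0.15\nprint('mask token rate:', n_masked / n_total)  # ~0.15 * 0.8 = 0.12\n"
  },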
  {
    "path": "src/data/policy_value.py",
    "content": "import torch\r\nfrom torch.utils.data import Dataset\r\n\r\n\r\nclass PolicyValueDataset(Dataset):\r\n    def __init__(self, data):\r\n        self.data = data\r\n\r\n    def __len__(self):\r\n        return len(self.data)\r\n\r\n    def __getitem__(self, idx):\r\n        dt = self.data[idx]\r\n\r\n        ret_dict = {'input_ids': torch.tensor(dt['seq'], dtype=torch.long),\r\n                    'labels': torch.tensor(dt['label'], dtype=torch.long),\r\n                    'values': torch.tensor(dt['value'], dtype=torch.float),\r\n                    'result': torch.tensor(dt['result'], dtype=torch.long)}\r\n        return ret_dict\r\n"
  },
  {
    "path": "src/features/__init__.py",
    "content": "\n"
  },
  {
    "path": "src/features/common.py",
    "content": "import cshogi\n\nfrom src.utils.shogi import reverse_piece_fn\n\n\ndef get_seq_from_board(board):\n    bp, wp = board.pieces_in_hand\n    if board.turn == cshogi.BLACK:\n        seq = board.pieces + bp + wp\n    else:\n        # 駒の順番を逆にして、駒を先後反転させる。持ち駒は逆にする。\n        seq = reverse_piece_fn(board.pieces[::-1]).tolist() + wp + bp\n    return seq\n"
  },
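  {
    "path": "examples/board_to_seq_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: shows the token sequence\n# produced by get_seq_from_board, which is exactly what BERT consumes. The\n# file path is hypothetical.\nimport cshogi\n\nfrom src.features.common import get_seq_from_board\n\nboard = cshogi.Board()\nseq = get_seq_from_board(board)\n\n# 81 square tokens + 7 black hand-piece counts + 7 white hand-piece counts\nprint(len(seq))  # 95\nprint(seq[:81])  # piece ids per square (0 = empty)\nprint(seq[81:])  # hand-piece counts, all 0 at the start position\n\n# after a black move it is white's turn: the board is flipped and piece ids\n# swapped so that the side to move always plays 'upward'\nboard.push_usi('7g7f')\nprint(get_seq_from_board(board)[:9])\n"
  },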
  {
    "path": "src/features/policy_value.py",
    "content": "import cshogi\r\nimport numpy as np\r\nfrom cshogi import move_drop_hand_piece, move_from, move_is_drop, move_is_promotion, move_to\r\n\r\nfrom src.features.common import get_seq_from_board\r\nfrom src.utils.shogi import (DOWN, DOWN_LEFT, DOWN_RIGHT, LEFT, MOVE_DIRECTION, MOVE_DIRECTION_PROMOTED, RIGHT, UP,\r\n                             UP2_LEFT, UP2_RIGHT, UP_LEFT, UP_RIGHT)\r\n\r\nboard = cshogi.Board()\r\n\r\n\r\ndef get_move_label(move, color):\r\n    if not move_is_drop(move):\r\n        from_sq = move_from(move)\r\n        to_sq = move_to(move)\r\n        if color == cshogi.WHITE:\r\n            to_sq = 80 - to_sq\r\n            from_sq = 80 - from_sq\r\n\r\n        # file: 筋, rank: 段\r\n        from_file, from_rank = divmod(from_sq, 9)\r\n        to_file, to_rank = divmod(to_sq, 9)\r\n        dir_file = to_file - from_file\r\n        dir_rank = to_rank - from_rank\r\n        if dir_rank < 0 and dir_file == 0:\r\n            move_direction = UP\r\n        elif dir_rank == -2 and dir_file == -1:\r\n            move_direction = UP2_RIGHT\r\n        elif dir_rank == -2 and dir_file == 1:\r\n            move_direction = UP2_LEFT\r\n        elif dir_rank < 0 and dir_file < 0:\r\n            move_direction = UP_RIGHT\r\n        elif dir_rank < 0 and dir_file > 0:\r\n            move_direction = UP_LEFT\r\n        elif dir_rank == 0 and dir_file < 0:\r\n            move_direction = RIGHT\r\n        elif dir_rank == 0 and dir_file > 0:\r\n            move_direction = LEFT\r\n        elif dir_rank > 0 and dir_file == 0:\r\n            move_direction = DOWN\r\n        elif dir_rank > 0 and dir_file < 0:\r\n            move_direction = DOWN_RIGHT\r\n        elif dir_rank > 0 and dir_file > 0:\r\n            move_direction = DOWN_LEFT\r\n        else:\r\n            raise RuntimeError\r\n\r\n        # promote\r\n        if move_is_promotion(move):\r\n            move_direction = MOVE_DIRECTION_PROMOTED[move_direction]\r\n\r\n    else:\r\n        # 持ち駒\r\n        move_direction = len(MOVE_DIRECTION) + move_drop_hand_piece(move) - 1\r\n        to_sq = move_to(move)\r\n        if color == cshogi.WHITE:\r\n            to_sq = 80 - to_sq\r\n\r\n    # labelのmaxは27*81-1=2186\r\n    move_label = move_direction * 81 + to_sq\r\n    return move_label\r\n\r\n\r\ndef get_result(result, color):\r\n    # 引き分け\r\n    if result == 0:\r\n        return 0.5\r\n    # 手番が勝ち\r\n    elif ((result == 1) and (color == cshogi.BLACK)) or ((result == 2) and (color == cshogi.WHITE)):\r\n        return 1\r\n    else:\r\n        return 0\r\n\r\n\r\ndef get_policy_value_label(hcpe):\r\n    board.reset()\r\n    board.set_hcp(hcpe['hcp'])\r\n\r\n    seq = get_seq_from_board(board)\r\n    label = get_move_label(hcpe['bestMove16'], board.turn)\r\n    value = 1 / (1 + np.exp(-hcpe['eval'] * 0.0013226))\r\n    result = get_result(hcpe['gameResult'], board.turn)\r\n\r\n    return seq, label, value, result\r\n\r\n\r\ndef get_policy_value_label_from_moves(moves):\r\n    n = np.random.randint(4, len(moves)-1)\r\n    board.reset()\r\n    for move in moves[:n]:\r\n        board.push(move)\r\n    seq = get_seq_from_board(board)\r\n    label = get_move_label(moves[n], board.turn)\r\n    value = 0.5\r\n    result = 0.5\r\n    return seq, label, value, result\r\n\r\n\r\ndef get_moves_from_lines(line):\r\n    board.reset()\r\n    moves = []\r\n    for move_usi in line.split()[2:]:\r\n        move = board.move_from_usi(move_usi)\r\n        board.push(move)\r\n        moves.append(move)\r\n    return moves\r\n"
  },
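  {
    "path": "examples/move_label_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: shows how a USI move is\n# mapped onto one of the 27 * 81 = 2187 policy labels (direction * 81 +\n# destination square). The file path is hypothetical.\nimport cshogi\n\nfrom src.features.policy_value import get_move_label\nfrom src.utils.shogi import UP\n\nboard = cshogi.Board()\nmove = board.move_from_usi('7g7f')  # pawn one square forward\n\nlabel = get_move_label(move, board.turn)\ndirection, to_sq = divmod(label, 81)\nprint(label, direction == UP, to_sq)  # 59 True 59: an UP move to square 59 (7f)\n"
  },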
  {
    "path": "src/model/__init__.py",
    "content": "\n"
  },
  {
    "path": "src/model/bert.py",
    "content": "import torch.nn as nn\r\nfrom transformers import BertConfig, BertForMaskedLM, BertModel\r\n\r\nfrom src.utils.shogi import pieces_list\r\n\r\nconfig = {\r\n    'vocab_size': len(pieces_list) + 4,  # MASK_TOKEN_ID, MASK, CLS, SEP\r\n    'hidden_size': 768,\r\n    'num_hidden_layers': 12,\r\n    'num_attention_heads': 12,\r\n    'intermediate_size': 3072,  # hidden_size * 4が目安\r\n    'hidden_act': 'gelu',\r\n    'hidden_dropout_prob': 0.1,\r\n    'attention_probs_dropout_prob': 0.1,\r\n    'max_position_embeddings': 512,  # 95(=81(マス目)+7(先手持駒)+7(後手持駒))でいいかも\r\n    'type_vocab_size': 1,  # 対の文章を入れない。つまりtoken_type_embeddingsは完全に無駄になっている。\r\n    'initializer_range': 0.02,\r\n}\r\nconfig = BertConfig.from_dict(config)\r\n\r\n\r\nclass BertMLM(nn.Module):\r\n    def __init__(self, model_dir=None):\r\n        super().__init__()\r\n        if model_dir is None:\r\n            self.bert = BertForMaskedLM(config)\r\n        else:\r\n            self.bert = BertForMaskedLM.from_pretrained(model_dir)\r\n\r\n    def forward(self, input_ids, labels):\r\n        return self.bert(input_ids=input_ids, labels=labels)\r\n\r\n\r\nclass BertPolicyValue(nn.Module):\r\n    def __init__(self, model_dir=None):\r\n        super().__init__()\r\n        if model_dir is None:\r\n            self.bert = BertModel(config)\r\n        else:\r\n            self.bert = BertModel.from_pretrained(model_dir)\r\n\r\n        self.policy_head = nn.Sequential(\r\n            nn.Linear(768, 768 * 2),\r\n            nn.Tanh(),\r\n            nn.Linear(768 * 2, 9 * 9 * 27)\r\n        )\r\n\r\n        self.value_head = nn.Sequential(\r\n            nn.Linear(768, 768 * 2),\r\n            nn.Tanh(),\r\n            nn.Linear(768 * 2, 1),\r\n            nn.Sigmoid()\r\n        )\r\n\r\n        self.loss_policy_fn = nn.CrossEntropyLoss()\r\n        self.loss_value_fn = nn.MSELoss()\r\n\r\n    def forward(self, input_ids, labels=None):\r\n        features = self.bert(input_ids=input_ids)['last_hidden_state']\r\n        policy = self.policy_head(features).mean(axis=1)\r\n        value = self.value_head(features).mean(axis=1).squeeze(1)\r\n        if labels is None:\r\n            return {'policy': policy, 'value': value}\r\n        else:\r\n            loss_policy = self.loss_policy_fn(policy, labels['labels'])\r\n            loss_value = self.loss_value_fn(value, labels['values'])\r\n            loss = loss_policy + loss_value\r\n            return {'loss_policy': loss_policy, 'loss_value': loss_value, 'loss': loss}\r\n"
  },
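  {
    "path": "examples/bert_shapes_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: instantiates the\n# randomly initialized policy-value model and checks the tensor shapes it\n# produces. The file path is hypothetical.\nimport torch\n\nfrom src.model.bert import BertPolicyValue, config\n\nmodel = BertPolicyValue()  # model_dir=None means random init, no pretrained weights\nmodel.eval()\n\n# a batch of 2 fake sequences of length 95 with token ids inside the vocabulary\ninput_ids = torch.randint(0, config.vocab_size, (2, 95))\nwith torch.no_grad():\n    out = model(input_ids)\n\nprint(out['policy'].shape)  # torch.Size([2, 2187]) = 9*9*27 move labels\nprint(out['value'].shape)  # torch.Size([2]), sigmoid win rates in [0, 1]\n"
  },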
  {
    "path": "src/pl_modules/__init__.py",
    "content": "from .mlm import MLMModule, MLMDataModule\nfrom .policy_value import PolicyValueModule, PolicyValueDataModule\n\n\ndef get_pl_modules(cfg):\n    if cfg.model_type == 'MLM':\n        return MLMModule(cfg), MLMDataModule(cfg)\n    elif cfg.model_type == 'PolicyValue':\n        return PolicyValueModule(cfg), PolicyValueDataModule(cfg)\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "src/pl_modules/mlm.py",
    "content": "from pathlib import Path\n\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoader\nfrom transformers import AdamW\n\nfrom src.data.mlm import MLMDataset\nfrom src.model.bert import BertMLM\n\n\nclass MLMModule(pl.LightningModule):\n\n    def __init__(self, hparams):\n        super().__init__()\n        self.hparams = hparams\n        self.model = BertMLM(hparams['model_dir'])\n\n    def forward(self, batch):\n        input_ids = batch['input_ids']\n        labels = batch['labels']\n        return self.model(input_ids=input_ids, labels=labels)\n\n    def training_step(self, batch, batch_idx):\n        outputs = self(batch)\n        loss = outputs[0]\n        self.log('loss', loss)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        outputs = self(batch)\n        loss = outputs[0].detach().cpu().numpy()\n        return {'loss': loss}\n\n    def validation_epoch_end(self, outputs):\n        val_loss = np.mean([out['loss'] for out in outputs])\n        self.log('steps', self.global_step)\n        self.log('val_loss', val_loss)\n\n    def configure_optimizers(self):\n        return AdamW(self.parameters(), lr=5e-5)\n\n\nclass MLMDataModule(pl.LightningDataModule):\n    def __init__(self, cfg):\n        super().__init__()\n        self.cfg = cfg\n\n    def setup(self, stage=None):\n        dataset_dir = Path(self.cfg.dataset_dir)\n        train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)\n        valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)\n        self.train_dataset = MLMDataset(train_data)\n        self.val_dataset = MLMDataset(valid_data)\n\n    def train_dataloader(self):\n        return DataLoader(self.train_dataset, **self.cfg.train_loader)\n\n    def val_dataloader(self):\n        return DataLoader(self.val_dataset, **self.cfg.val_loader)\n"
  },
  {
    "path": "src/pl_modules/policy_value.py",
    "content": "from pathlib import Path\n\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoader\nfrom transformers import AdamW\n\nfrom src.data.policy_value import PolicyValueDataset\nfrom src.model.bert import BertPolicyValue\n\n\nclass PolicyValueModule(pl.LightningModule):\n    def __init__(self, hparams):\n        super().__init__()\n        self.hparams = hparams\n        self.model = BertPolicyValue(hparams['model_dir'])\n\n    def forward(self, input_ids, labels=None):\n        output = self.model(input_ids, labels)\n        return output\n\n    def training_step(self, batch, batch_idx):\n        input_ids = batch.pop('input_ids')\n        output = self(input_ids, batch)\n        return {'loss': output['loss']}\n\n    def validation_step(self, batch, batch_idx):\n        input_ids = batch.pop('input_ids')\n        output = self(input_ids, batch)\n        for k, v in output.items():\n            output[k] = v.detach().cpu().numpy()\n        return output\n\n    def validation_epoch_end(self, outputs):\n        val_loss = np.mean([out['loss'] for out in outputs])\n        val_loss_policy = np.mean([out['loss_policy'] for out in outputs])\n        val_loss_value = np.mean([out['loss_value'] for out in outputs])\n        self.log('val_loss', val_loss)\n        self.log('val_loss_policy', val_loss_policy)\n        self.log('val_loss_value', val_loss_value)\n\n    def configure_optimizers(self):\n        return AdamW(self.parameters(), lr=5e-5)\n\n\nclass PolicyValueDataModule(pl.LightningDataModule):\n    def __init__(self, cfg):\n        super().__init__()\n        self.cfg = cfg\n\n    def setup(self, stage=None):\n        dataset_dir = Path(self.cfg.dataset_dir)\n        train_data = np.load(dataset_dir / 'train.npy', allow_pickle=True)\n        print('Load train data')\n        valid_data = np.load(dataset_dir / 'val.npy', allow_pickle=True)\n        print('Load val data')\n        self.train_dataset = PolicyValueDataset(train_data)\n        self.val_dataset = PolicyValueDataset(valid_data)\n\n    def train_dataloader(self):\n        return DataLoader(self.train_dataset, **self.cfg.train_loader)\n\n    def val_dataloader(self):\n        return DataLoader(self.val_dataset, **self.cfg.val_loader)\n\n"
  },
  {
    "path": "src/player/__init__.py",
    "content": "\n"
  },
  {
    "path": "src/player/base_player.py",
    "content": "import cshogi\n\n\nclass BasePlayer:\n    def __init__(self):\n        self.board = cshogi.Board()\n\n    def usi(self):\n        print('id name bert_player')\n        print('usiok')\n\n    def usinewgame(self):\n        pass\n\n    def setoption(self, option):\n        pass\n\n    def isready(self):\n        pass\n\n    def position(self, moves):\n        if moves[0] == 'startpos':\n            self.board.reset()\n            for move in moves[2:]:\n                self.board.push_usi(move)\n        elif moves[0] == 'sfen':\n            self.board.set_sfen(' '.join(moves[1:]))\n        # for debug\n        print(self.board.sfen())\n\n    def go(self):\n        pass\n\n    def quit(self):\n        pass\n"
  },
  {
    "path": "src/player/mcts_player.py",
    "content": "import time\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nimport cshogi\nimport numpy as np\nimport torch\n\nfrom src.features.common import get_seq_from_board\nfrom src.features.policy_value import get_move_label\nfrom src.pl_modules.policy_value import PolicyValueModule\nfrom src.player.base_player import BasePlayer\nfrom src.player.usi import usi\nfrom src.uct.uct_node import NOT_EXPANDED, NodeHash, UCT_HASH_SIZE, UctNode\nfrom src.utils.misc import boltzmann\n\n\nclass MCTSPlayer(BasePlayer):\n    def __init__(self, ckpt_path, playout_halt=1000, temperature=1, resign_threshold=0.01, c_puct=1):\n        super().__init__()\n        self.ckpt_path = ckpt_path\n        self.playout_halt = playout_halt\n        self.temperature = temperature\n        self.resign_threshold = resign_threshold\n        self.c_puct = c_puct\n\n        self.model = None\n        self.node_hash = NodeHash()\n        self.uct_nodes = [UctNode() for _ in range(UCT_HASH_SIZE)]\n        self.current_n_idx = None\n        self.playout_count = 0\n\n    def usi(self):\n        print('id name BERT-MCTS')\n        print('usiok')\n\n    def isready(self):\n        if self.model is None:\n            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model\n            self.model.cuda()\n            self.model.eval()\n        self.node_hash.initialize()\n        print('readyok')\n\n    def go(self):\n        if self.board.is_game_over():\n            print('bestmove resign')\n            return\n\n        # 探索情報をクリア\n        self.playout_count = 0\n\n        # 古いハッシュを削除\n        self.node_hash.delete_old_hash(self.board, self.uct_nodes)\n\n        # 探索開始時刻\n        begin_time = time.time()\n\n        # ルートノードの展開\n        self.current_n_idx = self.expand_node()\n\n        # 候補手が１つの場合はその手を返す\n        current_node = self.uct_nodes[self.current_n_idx]\n        child_moves = current_node.child_moves\n        child_num = len(child_moves)\n        if child_num == 1:\n            print('bestmove', cshogi.move_to_usi(child_moves[0]))\n            return\n\n        def get_bestmove_and_print_info():\n            # 探索にかかった時間を求める\n            finish_time = time.time() - begin_time\n\n            if self.board.move_number < 5:\n                selected_index = np.random.choice(np.arange(child_num), p=current_node.policy)\n            else:\n                # 訪問回数最大の手を選択する\n                selected_index = np.argmax(current_node.child_moves_count)\n\n            # 選択したノードの訪問回数0ならポリシーの値\n            if current_node.child_moves_count[selected_index] == 0:\n                best_wp = current_node.policy[selected_index]\n            # それ以外なら勝率の平均を出す\n            else:\n                best_wp = current_node.child_value_sum[selected_index] / current_node.child_moves_count[selected_index]\n\n            # 閾値以下なら投了\n            if best_wp < self.resign_threshold:\n                bestmove = 'resign'\n            else:\n                bestmove = cshogi.move_to_usi(child_moves[selected_index])\n\n            # valueを評価値のスケールに変換\n            if best_wp == 1:\n                cp = 30000\n            elif best_wp == 0:\n                cp = -30000\n            else:\n                cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)\n\n            nps = int(current_node.move_count / finish_time)\n            time_secs = int(finish_time * 1000)\n            nodes = current_node.move_count\n            hashfull = self.node_hash.get_usage_rate() * 100\n            print(f'info score cp {cp} 
hashfull {hashfull:.2f} time {time_secs} nodes {nodes} nps {nps} pv {bestmove}')\n            return bestmove\n\n        # プレイアウトを繰り返す\n        # 探索回数が閾値を超える、または探索が打ち切られたらループを抜ける\n        while self.playout_count < self.playout_halt:\n            self.playout_count += 1\n            self.uct_search(self.current_n_idx)\n            # 10回に1回 読みの状況を出力\n            if (self.playout_count+1) % 10 == 0:\n                get_bestmove_and_print_info()\n            # 探索を打ち切るか確認\n            if self.interruption_check() or not self.node_hash.enough_size:\n                break\n\n        bestmove = get_bestmove_and_print_info()\n        print('bestmove', bestmove)\n\n    def expand_node(self):\n        current_hash = self.board.zobrist_hash()\n        current_turn = self.board.turn\n        current_move_number = self.board.move_number\n        n_idx = self.node_hash.find_same_hash_index(current_hash, current_turn, current_move_number)\n\n        # 合流先が検知できれば、それを返す\n        if n_idx != UCT_HASH_SIZE:\n            return n_idx\n\n        # 空のインデックスを探す\n        n_idx = self.node_hash.search_empty_index(current_hash, current_turn, current_move_number)\n\n        # 現在のノードの初期化\n        current_node = self.uct_nodes[n_idx]\n        current_node.reset()\n\n        # 候補手の展開\n        current_node.child_moves = [move for move in self.board.legal_moves]\n        child_num = len(current_node.child_moves)\n        current_node.child_n_indices = [NOT_EXPANDED for _ in range(child_num)]\n        current_node.child_moves_count = np.zeros(child_num, dtype=np.int32)\n        current_node.child_value_sum = np.zeros(child_num, dtype=np.float32)\n\n        # ノードを評価\n        self.eval_node(n_idx)\n\n        return n_idx\n\n    def eval_node(self, n_idx):\n        current_node = self.uct_nodes[n_idx]\n        child_moves = current_node.child_moves\n        child_num = len(current_node.child_moves)\n        if child_num == 0:\n            # 指す手がない＝負け\n            current_node.value = 0\n            current_node.evaled = True\n        else:\n            # 現在の局面における方策と価値\n            seq = get_seq_from_board(self.board)\n            input_ids = torch.tensor([seq], dtype=torch.long).cuda()\n            with torch.no_grad():\n                output = self.model(input_ids)\n                value = output['value'].detach().cpu().numpy()[0]\n                policy_logits = output['policy'].detach().cpu().numpy()[0]\n\n            # 合法手でフィルターする\n            legal_move_labels = []\n            for move in child_moves:\n                legal_move_labels.append(get_move_label(move, self.board.turn))\n\n            # Boltzmann\n            policy = boltzmann(policy_logits[legal_move_labels], self.temperature)\n\n            # ノードの値を更新\n            current_node.policy = policy\n            current_node.value = value\n            current_node.evaled = True\n\n    def uct_search(self, n_idx):\n        current_node = self.uct_nodes[n_idx]\n        child_moves = current_node.child_moves\n        child_n_indices = current_node.child_n_indices\n        child_num = len(child_moves)\n\n        # 詰みは負け->元ノードから見たら勝ち\n        if child_num == 0:\n            return 1\n\n        # PUCTアルゴリズムによるUCB値最大の手\n        next_c_idx = self.select_max_ucb_child(n_idx)\n        next_move = child_moves[next_c_idx]\n        next_n_idx = child_n_indices[next_c_idx]\n\n        # 選んだ手を着手\n        self.board.push(next_move)\n\n        if next_n_idx == NOT_EXPANDED:\n            # 選択した手に対応するコードが未展開なら展開\n            # ノードの展開（ノード展開処理の中でノードを評価する）\n            next_n_idx = 
self.expand_node()\n            child_n_indices[next_c_idx] = next_n_idx\n            child_node = self.uct_nodes[next_n_idx]\n            result = 1 - child_node.value\n        else:\n            # 展開済みなら一手深く読む\n            result = self.uct_search(next_n_idx)\n\n        # バックアップ\n        # 探索結果の反映\n        current_node.move_count += 1\n        current_node.child_value_sum[next_c_idx] += result\n        current_node.child_moves_count[next_c_idx] += 1\n        # 手を戻す\n        self.board.pop(next_move)\n\n        return 1 - result\n\n    def select_max_ucb_child(self, c_idx):\n        current_node = self.uct_nodes[c_idx]\n        child_wins_count = current_node.child_value_sum\n        child_moves_count = current_node.child_moves_count\n        child_num = len(child_moves_count)\n\n        # child_move_countが0の場所のqは0.5で埋める\n        q = np.divide(child_wins_count, child_moves_count, out=np.repeat(0.5, child_num), where=child_moves_count != 0)\n        u = np.sqrt(current_node.move_count) / (1 + child_moves_count)\n        ucb = q + self.c_puct * current_node.policy * u\n\n        return np.argmax(ucb)\n\n    # 探索を打ち切るか確認\n    def interruption_check(self):\n        child_move_count = self.uct_nodes[self.current_n_idx].child_moves_count\n        rest = self.playout_halt - self.playout_count\n\n        # 探索回数が最も多い手と次に多いてを求める\n        second, first = child_move_count[np.argpartition(child_move_count, -2)[-2:]]\n\n        # 残りの探索を全て次善手に費やしても最善手を超えられない場合は探索を打ち切る\n        if first - second > rest:\n            return True\n        else:\n            return False\n\n\ndef argparse():\n    parser = ArgumentParser()\n    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')\n    args, _ = parser.parse_known_args()\n    print('Command Line Args:')\n    print(args)\n    return args\n\n\ndef main(args):\n    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()\n    player = MCTSPlayer(ckpt_path)\n    usi(player)\n\n\nif __name__ == '__main__':\n    args = argparse()\n    main(args)\n"
  },
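  {
    "path": "examples/puct_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: the PUCT selection rule\n# used by MCTSPlayer.select_max_ucb_child, reduced to plain numbers. The file\n# path and the toy statistics are hypothetical.\nimport numpy as np\n\nc_puct = 1.0\npolicy = np.array([0.6, 0.3, 0.1])  # network priors for three children\nchild_value_sum = np.array([1.0, 2.5, 0.0])\nchild_moves_count = np.array([2, 4, 0])\nmove_count = child_moves_count.sum()\n\n# unvisited children get a neutral q of 0.5\nq = np.divide(child_value_sum, child_moves_count,\n              out=np.repeat(0.5, 3), where=child_moves_count != 0)\nu = np.sqrt(move_count) / (1 + child_moves_count)\nucb = q + c_puct * policy * u\n\nprint(q)  # [0.5   0.625 0.5  ]\nprint(np.argmax(ucb))  # 0: the exploration bonus favors the high-prior child\n"
  },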
  {
    "path": "src/player/policy_player.py",
    "content": "from argparse import ArgumentParser\nfrom pathlib import Path\nimport numpy as np\n\nimport cshogi\nimport torch\nimport torch.nn.functional as F\n\nfrom src.features.common import get_seq_from_board\nfrom src.features.policy_value import get_move_label\nfrom src.pl_modules.policy_value import PolicyValueModule\nfrom src.player.base_player import BasePlayer\nfrom src.player.usi import usi\nfrom src.utils.misc import greedy\n\n\nclass PolicyPlayer(BasePlayer):\n    def __init__(self, ckpt_path):\n        super().__init__()\n        self.ckpt_path = ckpt_path\n        self.model = None\n\n    def usi(self):\n        print('id name BERT-Policy')\n        print('usiok')\n\n    def isready(self):\n        if self.model is None:\n            self.model = PolicyValueModule.load_from_checkpoint(self.ckpt_path).model\n            self.model.cuda()\n            self.model.eval()\n        print('readyok')\n\n    def go(self):\n        if self.board.is_game_over():\n            print('bestmove resign')\n            return\n\n        seq = get_seq_from_board(self.board)\n        input_ids = torch.tensor([seq], dtype=torch.long).cuda()\n\n        with torch.no_grad():\n            output = self.model(input_ids)\n            policy = F.softmax(output['policy'], dim=1).detach().cpu().numpy()[0]\n\n        # 全ての合法手について\n        legal_moves = []\n        legal_policy = []\n        for move in self.board.legal_moves:\n            # ラベルに変換\n            label = get_move_label(move, self.board.turn)\n            # 合法手とその指し手の確率（logits）を格納\n            legal_moves.append(move)\n            legal_policy.append(policy[label])\n\n        selected_index = greedy(legal_policy)\n        bestmove = cshogi.move_to_usi(legal_moves[selected_index])\n        best_wp = legal_policy[selected_index]\n        # valueを評価値のスケールに変換\n        if best_wp == 1:\n            cp = 30000\n        elif best_wp == 0:\n            cp = -30000\n        else:\n            cp = int(-np.log(1 / best_wp - 1) * 756.0864962951762)\n\n        print(f'info score cp {cp} pv {bestmove}')\n\n        print('bestmove', bestmove)\n\n\ndef argparse():\n    parser = ArgumentParser()\n    parser.add_argument('--ckpt_path', type=str, default='./work_dirs/last.ckpt')\n    args, _ = parser.parse_known_args()\n    print('Command Line Args:')\n    print(args)\n    return args\n\n\ndef main(args):\n    ckpt_path = (Path(__file__).parent.parent.parent / args.ckpt_path).resolve()\n    player = PolicyPlayer(ckpt_path)\n    usi(player)\n\n\nif __name__ == '__main__':\n    args = argparse()\n    main(args)\n"
  },
  {
    "path": "src/player/usi.py",
    "content": "from src.player.base_player import BasePlayer\n\n\ndef usi(player: BasePlayer):\n    while True:\n        cmd_line = input()\n        cmd = cmd_line.split(' ', 1)\n        cmd = [c.rstrip() for c in cmd]\n\n        if cmd[0] == 'usi':\n            player.usi()\n        elif cmd[0] == 'setoption':\n            option = cmd[1].split(' ')\n            player.setoption(option)\n        elif cmd[0] == 'isready':\n            player.isready()\n        elif cmd[0] == 'usinewgame':\n            player.usinewgame()\n        elif cmd[0] == 'position':\n            moves = cmd[1].split(' ')\n            player.position(moves)\n        elif cmd[0] == 'go':\n            player.go()\n        elif cmd[0] == 'quit':\n            player.quit()\n            break\n"
  },
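  {
    "path": "examples/usi_loop_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: drives the usi() command\n# loop with a scripted session instead of a GUI, using a stub player. The file\n# path and the stubbed commands are hypothetical.\nimport builtins\n\nfrom src.player.base_player import BasePlayer\nfrom src.player.usi import usi\n\n\nclass StubPlayer(BasePlayer):\n    def isready(self):\n        print('readyok')\n\n    def go(self):\n        print('bestmove resign')\n\n\n# a GUI would normally send these commands over stdin\ncommands = iter([\n    'usi',\n    'isready',\n    'position startpos moves 7g7f',\n    'go',\n    'quit',\n])\nbuiltins.input = lambda: next(commands)\n\nusi(StubPlayer())\n"
  },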
  {
    "path": "src/uct/__init__.py",
    "content": "\n"
  },
  {
    "path": "src/uct/uct_node.py",
    "content": "# ノードの上限値\nUCT_HASH_SIZE = 4096\n# 未展開のノードのインデックス\nNOT_EXPANDED = -1\n\n\n# ゾブリストハッシュ値をUCT_HASH_SIZEに圧縮\ndef hash_to_index(zhash):\n    return ((zhash & 0xffffffff) ^ ((zhash >> 32) & 0xffffffff)) & (UCT_HASH_SIZE - 1)\n\n\nclass NodeHashEntry:\n    def __init__(self):\n        self.hash = 0  # ゾブリストハッシュの値\n        self.color = 0  # 手番\n        self.moves = 0  # ゲーム開始からの手数\n        self.flag = False  # 使用中か識別するフラグ\n\n    def reset(self):\n        self.hash = 0\n        self.color = 0\n        self.moves = 0\n        self.flag = False\n\n\nclass NodeHash:\n    def __init__(self):\n        self.used = 0\n        self.enough_size = True\n        self.node_hash = None\n\n    def initialize(self):\n        self.used = 0\n        self.enough_size = True\n\n        if self.node_hash is None:\n            self.node_hash = [NodeHashEntry() for _ in range(UCT_HASH_SIZE)]\n        else:\n            for i in range(UCT_HASH_SIZE):\n                self.node_hash[i].reset()\n\n    # 未使用のインデックスを探して返す\n    def search_empty_index(self, zhash, color, moves):\n        key = hash_to_index(zhash)\n        i = key\n\n        while True:\n            if not self.node_hash[i].flag:\n                self.node_hash[i].hash = zhash\n                self.node_hash[i].color = color\n                self.node_hash[i].moves = moves\n                self.node_hash[i].flag = True\n                self.used += 1\n                if self.get_usage_rate() > 0.9:\n                    self.enough_size = False\n                return i\n            i += 1\n            if i >= UCT_HASH_SIZE:\n                i = 0\n            # もとに戻ってくる\n            if i == key:\n                return UCT_HASH_SIZE\n\n    # ハッシュ値に対応するインデックスを返す\n    def find_same_hash_index(self, zhash, color, moves):\n        key = hash_to_index(zhash)\n        i = key\n\n        while True:\n            # もろもろの属性があっていたらiを返す\n            if self.node_hash[i].flag and self.node_hash[i].hash == zhash and self.node_hash[i].color == color and \\\n                    self.node_hash[i].moves == moves:\n                return i\n            else:\n                return UCT_HASH_SIZE\n\n    # 使用中のノードを残す\n    def save_used_hash(self, board, uct_nodes, n_idx):\n        self.node_hash[n_idx].flag = True\n        self.used += 1\n\n        current_node = uct_nodes[n_idx]\n        child_n_indices = current_node.child_n_indices\n        child_moves = current_node.child_moves\n        child_num = len(child_moves)\n        for i in range(child_num):\n            if child_n_indices[i] != NOT_EXPANDED and not self.node_hash[child_n_indices[i]].flag:\n                board.push(child_moves[i])\n                self.save_used_hash(board, uct_nodes, child_n_indices[i])\n                board.pop(child_moves[i])\n\n    # 古いハッシュを削除\n    def delete_old_hash(self, board, uct_node):\n        # 現在の局面をルートとする局面以外を削除する\n        n_idx = self.find_same_hash_index(board.zobrist_hash(), board.turn, board.move_number)\n\n        self.used = 0\n        for i in range(UCT_HASH_SIZE):\n            self.node_hash[i].reset()\n\n        if n_idx != UCT_HASH_SIZE:\n            self.save_used_hash(board, uct_node, n_idx)\n\n        self.enough_size = True\n\n    def get_usage_rate(self):\n        return self.used / UCT_HASH_SIZE\n\n\nclass UctNode:\n    def __init__(self):\n        self.evaled = False  # 評価済みフラグ\n        self.move_count = 0  # ノードの訪問回数\n        self.value = 0  # ノードの価値ネットワークの評価（予測勝率）\n        self.policy = None  # 正規化した方策ネットワークの出力（子ノード分の長さを持つ）\n        self.child_moves 
= None  # 子ノードの指し手\n        self.child_n_indices = None  # 子ノードのインデックス\n        self.child_moves_count = None  # 子ノードの訪問回数 (UCB用)\n        self.child_value_sum = None  # 子ノードのvalueの合計 (UCB用)\n\n    def reset(self):\n        self.evaled = False\n        self.move_count = 0\n        self.value = 0\n        self.policy = None\n        self.child_moves = None\n        self.child_n_indices = None\n        self.child_moves_count = None\n        self.child_value_sum = None\n"
  },
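  {
    "path": "examples/node_hash_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: shows how NodeHash\n# stores positions with open addressing, so two positions whose Zobrist hashes\n# collide on hash_to_index still get distinct slots and are both found again.\n# The hash values and the file path are hypothetical.\nfrom src.uct.uct_node import NodeHash, hash_to_index\n\nnode_hash = NodeHash()\nnode_hash.initialize()\n\n# two distinct zobrist hashes that compress to the same table index\nh1 = 0x12345678_9abcdef0\nh2 = h1 ^ ((1 << 32) | 1)  # flips a bit pair that cancels out in hash_to_index\nassert h1 != h2 and hash_to_index(h1) == hash_to_index(h2)\n\ni1 = node_hash.search_empty_index(h1, color=0, moves=10)\ni2 = node_hash.search_empty_index(h2, color=0, moves=10)\nprint(i1, i2)  # adjacent slots thanks to linear probing\n\n# lookups probe the same chain, so both entries are found\nassert node_hash.find_same_hash_index(h1, 0, 10) == i1\nassert node_hash.find_same_hash_index(h2, 0, 10) == i2\n"
  },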
  {
    "path": "src/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/utils/hcpe.py",
    "content": "import cshogi\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom src.features.policy_value import get_policy_value_label\n\n\ndef get_data_from_hcpe(hcpes):\n    data = []\n    for hcpe in tqdm(hcpes):\n        data.append(get_policy_value_label(hcpe))\n\n    data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])\n    return data\n\n\ndef load_hcpes(hcpe_paths):\n    hcpe_list = []\n    for hcpe_path in hcpe_paths:\n        hcpe_list.append(np.fromfile(hcpe_path, dtype=cshogi.HuffmanCodedPosAndEval))\n    hcpes = np.concatenate(hcpe_list)\n    return hcpes\n"
  },
  {
    "path": "src/utils/misc.py",
    "content": "import numpy as np\n\n\ndef greedy(logits):\n    return np.asarray(logits).argmax()\n\n\ndef boltzmann(logits, temperature):\n    logits /= temperature\n    logits -= logits.max()\n    probabilities = np.exp(logits)\n    probabilities /= probabilities.sum()\n    return probabilities\n\n"
  },
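  {
    "path": "examples/boltzmann_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: shows how the\n# temperature in boltzmann() controls how sharply the policy concentrates on\n# the best move. The file path is hypothetical.\nimport numpy as np\n\nfrom src.utils.misc import boltzmann\n\nlogits = np.array([2.0, 1.0, 0.5])\n\n# boltzmann() modifies its input in place, hence the copies\nprint(boltzmann(logits.copy(), 1.0))  # mildly peaked\nprint(boltzmann(logits.copy(), 0.1))  # almost greedy\nprint(boltzmann(logits.copy(), 10.0))  # close to uniform\n"
  },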
  {
    "path": "src/utils/sfen.py",
    "content": "from pathlib import Path\n\nimport numpy as np\n\nfrom src.features.policy_value import get_moves_from_lines, get_policy_value_label_from_moves\n\n\ndef get_gokaku_sfen_paths(base_dir: Path, max_num_of_moves=40):\n    assert base_dir.name == 'ShogiGokakuKyokumen'\n    sfen_paths = []\n    # TODO: 正規表現にする\n    for sfen_path in sorted(list(base_dir.glob('*/*/*.sfen')) + list(base_dir.glob('*/*/*/*.sfen'))):\n        # 互角で40手以下の棋譜\n        if '互角' in sfen_path.stem and (int(sfen_path.stem.split('手目')[0][-3:]) <= max_num_of_moves):\n            sfen_paths.append(sfen_path)\n    return sfen_paths\n\n\ndef get_data_from_sfen(sfens):\n    # sfen形式のデータは棋譜の途中を使って学習データを作成する\n    data = []\n    for sfen in sfens:\n        moves = get_moves_from_lines(sfen)\n        # 手数が6手以上のものを使う\n        if len(moves) < 6:\n            continue\n        # 30手なら6局面を抽出\n        for _ in range(max(1, len(moves) // 5)):\n            dt = get_policy_value_label_from_moves(moves)\n            data.append(dt)\n\n    data = np.array(data, dtype=[('seq', 'O'), ('label', 'u2'), ('value', 'f4'), ('result', 'f2')])\n    return data\n\n\ndef load_sfen(sfen_path):\n    with open(sfen_path) as f:\n        sfens = [l.rstrip() for l in f.readlines()]\n    return sfens\n\n\ndef load_sfens(sfen_paths):\n    sfens = []\n    for kifu_path in sfen_paths:\n        sfens.extend(load_sfen(kifu_path))\n    return sfens\n"
  },
  {
    "path": "src/utils/shogi.py",
    "content": "import numpy as np\r\n\r\n# 移動の定数\r\nMOVE_DIRECTION = [\r\n    UP, UP_LEFT, UP_RIGHT, LEFT, RIGHT, DOWN, DOWN_LEFT, DOWN_RIGHT, UP2_LEFT, UP2_RIGHT,\r\n    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,\r\n    DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE\r\n] = range(20)\r\n\r\n# 成り変換テーブル\r\nMOVE_DIRECTION_PROMOTED = [\r\n    UP_PROMOTE, UP_LEFT_PROMOTE, UP_RIGHT_PROMOTE, LEFT_PROMOTE, RIGHT_PROMOTE, DOWN_PROMOTE, DOWN_LEFT_PROMOTE,\r\n    DOWN_RIGHT_PROMOTE, UP2_LEFT_PROMOTE, UP2_RIGHT_PROMOTE\r\n]\r\n\r\n# 指し手を表すラベルの数\r\nMOVE_DIRECTION_LABEL_NUM = len(MOVE_DIRECTION) + 7  # 7は持ち駒の種類\r\n\r\n# 先手駒と後手駒を逆に変換する辞書(reverse_piece_fn)の用意\r\nbw_dict = {k: k + 16 for k in range(1, 15)}  # 1~14に自駒, 17~30に敵駒\r\nwb_dict = {v: k for k, v in bw_dict.items()}\r\npieces_dict = {**bw_dict, **wb_dict, 0: 0}  # 0は空きマス\r\npieces_list = list(pieces_dict.keys())\r\nreverse_piece_fn = np.vectorize(pieces_dict.get)\r\n"
  },
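  {
    "path": "examples/piece_reverse_demo.py",
    "content": "# Illustrative sketch, not part of the original repo: shows the piece-id\n# vocabulary and the black/white swap applied when a position is encoded from\n# the white side. The file path is hypothetical.\nimport numpy as np\n\nfrom src.utils.shogi import pieces_list, reverse_piece_fn\n\nprint(sorted(pieces_list))  # 0 (empty), 1-14 (own pieces), 17-30 (opponent pieces)\n\nsquares = np.array([0, 1, 14, 17, 30])\nprint(reverse_piece_fn(squares))  # [ 0 17 30  1 14]\n"
  },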
  {
    "path": "tools/download_and_build_lesserkai.sh",
    "content": "wget http://shogidokoro.starfree.jp/download/LesserkaiSrc.zip -P ./work_dirs\nunzip ./work_dirs/LesserkaiSrc.zip -d ./work_dirs/\nrm ./work_dirs/LesserkaiSrc.zip\ncd ./work_dirs/LesserkaiSrc/Lesserkai\nmake"
  },
  {
    "path": "tools/make_dataset.py",
    "content": "from pathlib import Path\n\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nfrom src.utils.hcpe import get_data_from_hcpe, load_hcpes\nfrom src.utils.sfen import get_data_from_sfen, get_gokaku_sfen_paths, load_sfens\n\n\ndef main():\n    gokaku_dir = Path('./data/ShogiGokakuKyokumen/')\n    hcpe_dir = Path('./data/hcpe/')\n    dataset_base_dir = Path('./data/dataset/')\n    dataset_base_dir.mkdir(exist_ok=True)\n\n    for num_of_moves in [40, 100]:\n        sfen_paths = get_gokaku_sfen_paths(gokaku_dir, num_of_moves)\n        sfens = load_sfens(sfen_paths)\n        data = get_data_from_sfen(sfens)\n        train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)\n\n        dataset_dir = dataset_base_dir / f'gokaku_{num_of_moves:03d}'\n        dataset_dir.mkdir()\n\n        np.save(dataset_dir / 'train.npy', train_data)\n        np.save(dataset_dir / 'val.npy', valid_data)\n\n    # selfplayの棋譜\n    hcpe_paths = sorted(hcpe_dir.glob('selfplay-*'))\n    hcpes = load_hcpes(hcpe_paths)\n    data = get_data_from_hcpe(hcpes)\n\n    train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)\n\n    dataset_dir = dataset_base_dir / 'selfplay'\n    dataset_dir.mkdir()\n\n    np.save(dataset_dir / 'train.npy', train_data)\n    np.save(dataset_dir / 'val.npy', valid_data)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
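  {
    "path": "examples/dataset_peek.py",
    "content": "# Illustrative sketch, not part of the original repo: inspects the structured\n# arrays written by tools/make_dataset.py. Assumes the selfplay dataset has\n# already been built; the file path of this example is hypothetical.\nimport numpy as np\n\ndata = np.load('./data/dataset/selfplay/train.npy', allow_pickle=True)\n\nprint(data.dtype)  # structured dtype with seq/label/value/result fields\nprint(len(data))\nfirst = data[0]\nprint(len(first['seq']))  # 95-token board sequence\nprint(first['label'], first['value'], first['result'])\n"
  },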
  {
    "path": "tools/pl_to_transformers.py",
    "content": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nimport torch\n\nfrom src.model.bert import config\n\n\ndef argparse():\n    parser = ArgumentParser(description='Convert pl checkpoint to transformers format')\n    parser.add_argument('ckpt_path', type=str)\n    args, _ = parser.parse_known_args()\n    return args\n\n\ndef main(args):\n    ckpt_path = Path(args.ckpt_path)\n    ckpt_dir = ckpt_path.parent\n\n    ckpt = torch.load(ckpt_path)\n    state_dict = ckpt['state_dict']\n    state_dict = {'.'.join(k.split('.')[2:]): v for k, v in state_dict.items()}\n\n    # 同一ディレクトリにpytorch_model.binとconfig.jsonが必要\n    state_dict_path = ckpt_dir / f'pytorch_model.bin'\n    torch.save(state_dict, state_dict_path)\n    config.to_json_file(ckpt_dir / 'config.json')\n\n\nif __name__ == '__main__':\n    args = argparse()\n    main(args)\n"
  },
  {
    "path": "tools/test_engine.py",
    "content": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nfrom cshogi import cli\n\n\ndef parse_args():\n    parser = ArgumentParser()\n    parser.add_argument('engine_path')\n    return parser.parse_args()\n\n\ndef main(args):\n    benchmark = str(Path('./work_dirs/LesserkaiSrc/Lesserkai/Lesserkai').resolve())\n    engine = str((Path(args.engine_path)).resolve())\n\n    print('benchmark: ', benchmark)\n    print('engine: ', engine)\n\n    cli.main(benchmark, engine)\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "tools/train.py",
    "content": "from argparse import ArgumentParser\nfrom pathlib import Path\n\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom pytorch_lightning import loggers as pl_loggers\nfrom pytorch_lightning.callbacks import ModelCheckpoint\n\nfrom src.pl_modules import get_pl_modules\n\n\ndef argparse():\n    parser = ArgumentParser()\n    parser.add_argument('cfg', type=str)\n    parser.add_argument('--log_dir', type=str, default='./work_dirs')\n    parser.add_argument('--ckpt_path', type=str, default='')\n    parser = pl.Trainer.add_argparse_args(parser)\n    args, _ = parser.parse_known_args()\n    return args\n\n\ndef main(args):\n    cfg = OmegaConf.load(args.cfg)\n    pl.seed_everything(cfg.seed)\n\n    # configのtrain_paramsをargsに反映\n    for k, v in cfg.train_params.items():\n        setattr(args, k, v)\n\n    # Disable default checkpoint callback\n    args.checkpoint_callback = False\n    trainer = pl.Trainer.from_argparse_args(args)\n    trainer.logger = pl_loggers.TensorBoardLogger(save_dir=args.log_dir, name=Path(args.cfg).stem, default_hp_metric=False)\n    trainer.callbacks.append(ModelCheckpoint(filename='{step:07d}-{val_loss:.2f}', monitor='val_loss',\n                                             save_top_k=1, save_last=True))\n\n    model, data = get_pl_modules(cfg)\n    if args.ckpt_path:\n        model = model.load_from_checkpoint(args.ckpt_path)\n    trainer.fit(model, datamodule=data)\n\n\nif __name__ == '__main__':\n    args = argparse()\n    main(args)\n"
  }
]